Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-20 06:12:07 +00:00
Compare commits
194 Commits
Commit SHAs:
8f8819795f  95ccba3705  b400c275e4  84ab88d843  4645678ff0  ad765cd06b
16b4b7974a  459fbdebe3  449cee49ca  73e797528d  fe56f760df  d62c941c56
9a8d0462e1  5861da1ad7  0b28aabb94  24bec29ffc  37104acd75  87a0af4ec2
9c26b52940  d23b385eee  d9bb9bebc9  3d059f91ab  0142550096  f5f14dc660
54d15462dc  2e60a8dd65  e5503eba78  e497bc09f6  67ce543e04  83fe45c15e
11f2eec10e  a35fbdb925  8c2c348f3c  095775e05c  0b3e3db043  f91434e99b
8b91f92978  27ed848676  83ef364177  83b7b7bb92  c73ae0bd88  d4c6faa67b
4ac06ddf56  f01dc9e743  5c5528e362  ed46c2c414  f74c36fe0d  ae4451c3da
b447f7e821  094975c3a8  dc5f05f8e6  124398fa57  c5ecc7a4de  cae0cbe87d
bbe218a4f7  58a65f7914  976eae216f  622908deab  55a6618434  036d802b62
8e92942a18  3208d1cd1d  cdf70d6a28  ab9dafc68f  31766dad77  ec35976f82
cb42b3ad83  491ed9e11d  144d99c147  08bbfa16a1  d8ff7f2623  e88f6f6ee9
fa4e9511f8  a914a21899  aad9c2b0bd  1f35cc7a31  683ff53fa3  5eec3a8bb6
b0069e0485  d7a24c03cf  cea9dbc971  c00add9c03  97c5f7e685  1cae3197c4
3498f6085e  142a49a80d  06dfe9abfe  ed96ba6503  feaa2477b7  230aa25641
9c89d0070e  fde3234cbc  d6a0c67e2f  a7448661f7  5543fdc765  b8a4928d0e
8a1cfd6122  794ec58b75  f0ed76583c  cfd4fbb479  6df0fc0b55  d6881c37ab
8a211dc7fc  4cccce4b44  76bcb4948d  b86c3947ab  8a870b31b9  571ac9b507
4b8cda684b  856709d5c3  36223f834e  0ef8c8a97a  c1cf36c0dc  dd2bd5fdb3
88fd56f549  e3f2018cb5  bb69c5b199  c9d68945cc  c07a2cc82b  065aabb13d
cb747b33da  80e7d98f88  ee0dffcd14  4ef2e045c9  73b7cf83f6  eb3df0f46f
c690da5973  db922eb77e  40b00275b2  6cb41a80a1  d2ff68e98d  d9dda11726
d937eb64da  18c4607d46  29a0893b67  0a89902663  4e172028aa  6ab02931cf
cc212154e0  1dd346666a  1d3c9beba8  2dfe3b3ee6  64a33c1f05  bdb3e488e4
17367438f3  b980848abf  447a5b2f87  630f198624  8f6146f11a  eecca27113
6e982f43a1  c20025dbf7  de19e7e844  d61f14f271  885144166f  82f6ea1b71
5f78ec32a5  922cc38fbc  120bd3e3bb  1470aec9d9  203cade244  46994b34fb
dc9b8e9814  3c7ae48f7f  cc8b9650bd  e07acc7f68  880ab9c2f3  1660154ae6
2e22164d4a  83624a07be  01067f8ba8  4f7e00f4ce  da5ab46705  a9c7d2e3b6
afb6c728d8  d37a43e581  23bc38b10d  ab5f616920  8f66d323d0  7eeefa3b57
a72f339c79  11ab329883  6f0b8c947d  1708865fdc  ea7f4082c4  3bb3fd19ae
bf59118a93  c3bd7212c2  f01f2fb6e7  07b01293c5  cc66dccbe8  82c24f7420
a2d878fa0f  b2fac5d947
.github/workflows/build.yaml (vendored): 169 changed lines
@@ -6,10 +6,11 @@ on:
       hardware:
         type: string
         description: Hardware
         # options:
         # - cuda
+        # - cuda-trtllm
         # - rocm
         # - intel
         required: true
       release-tests:
         description: "Run release integration tests"
@@ -24,22 +25,34 @@ jobs:
       docker_volume: ${{ steps.final.outputs.docker_volume }}
       docker_devices: ${{ steps.final.outputs.docker_devices }}
       runs_on: ${{ steps.final.outputs.runs_on }}
-      label: ${{ steps.final.outputs.label }}
+      label_extension: ${{ steps.final.outputs.label_extension }}
       extra_pytest: ${{ steps.final.outputs.extra_pytest }}
     concurrency:
       group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     runs-on:
-      group: aws-highmemory-32-plus-priv
+      group: aws-highmemory-64-plus-priv
     permissions:
       contents: write
       packages: write
+      id-token: write
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
       - name: Inject slug/short variables
         uses: rlespinasse/github-slug-action@v4.4.1
-      - name: Construct harware variables
+      - name: Inject required variables for sccache to interact with Github Actions Cache
+        uses: actions/github-script@v7
+        with:
+          script: |
+            core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
+            core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+
+      - name: Extract TensorRT-LLM version
+        run: |
+          echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
+          echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
+      - name: Construct hardware variables
         shell: bash
         run: |
           case ${{ inputs.hardware }} in
@@ -51,15 +64,34 @@ jobs:
              export runs_on="aws-g6-12xl-plus-priv-cache"
              export platform=""
              export extra_pytest=""
+             export target=""
+             ;;
+           cuda-trtllm)
+             export dockerfile="Dockerfile_trtllm"
+             export label_extension="-trtllm"
+             export docker_volume="/mnt/cache"
+             export docker_devices=""
+             export runs_on="ubuntu-latest"
+             export platform=""
+             export extra_pytest=""
+             if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
+               export build_type="release";
+               export target="";
+             else
+               export build_type="dev";
+               export target="ci-runtime";
+             fi
              ;;
            rocm)
              export dockerfile="Dockerfile_amd"
              export label_extension="-rocm"
              export docker_devices="/dev/kfd,/dev/dri"
              export docker_volume="/mnt"
-             export runs_on="amd-gpu-runners"
+             # This runner was deactivated.
+             export runs_on="ubuntu-latest"
              export platform=""
              export extra_pytest="-k test_flash_gemma_gptq_load"
+             export target=""
              ;;
            intel-xpu)
              export dockerfile="Dockerfile_intel"
@@ -69,6 +101,7 @@ jobs:
              export runs_on="ubuntu-latest"
              export platform="xpu"
              export extra_pytest=""
+             export target=""
              ;;
            intel-cpu)
              export dockerfile="Dockerfile_intel"
@@ -79,7 +112,27 @@ jobs:
              export runs_on="aws-highmemory-32-plus-priv"
              export platform="cpu"
              export extra_pytest="-k test_flash_gemma_simple"
+             export target=""
              ;;
+           neuron)
+             export dockerfile="Dockerfile.neuron"
+             export label_extension="-neuron"
+             export docker_devices="/dev/neuron0"
+             export docker_volume="/mnt/cache"
+             export runs_on="aws-inf2-8xlarge"
+             export platform="cpu"
+             export extra_pytest="--neuron"
+             export target=""
+             ;;
+           gaudi)
+             export dockerfile="Dockerfile_gaudi"
+             export label_extension="-gaudi"
+             export docker_volume="/mnt/cache"
+             export docker_devices=""
+             export runs_on="ubuntu-latest"
+             export platform=""
+             export extra_pytest=""
+             export target=""
           esac
           echo $dockerfile
           echo "Dockerfile=${dockerfile}"
@@ -88,19 +141,22 @@ jobs:
           echo $runs_on
           echo $platform
           echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
-          echo "LABEL=${label_extension}" >> $GITHUB_ENV
+          echo "LABEL_EXTENSION=${label_extension}" >> $GITHUB_ENV
           echo "PLATFORM=${platform}" >> $GITHUB_ENV
           echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
           echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
           echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
           echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
           echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
+          echo "TARGET=${target}" >> $GITHUB_ENV
+          echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
       - name: Initialize Docker Buildx
         uses: docker/setup-buildx-action@v3
         with:
           install: true
           buildkitd-config: /tmp/buildkitd.toml
       - name: Login to internal Container Registry
+        if: github.event_name != 'pull_request'
         uses: docker/login-action@v3
         with:
           username: ${{ secrets.REGISTRY_USERNAME }}
@@ -113,6 +169,12 @@ jobs:
           registry: ghcr.io
           username: ${{ github.actor }}
           password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Login to Docker Hub Container Registry
+        uses: docker/login-action@v3
+        with:
+          registry: docker.io
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_PASSWORD }}
       - name: Login to Azure Container Registry
         if: github.event_name != 'pull_request'
         uses: docker/login-action@v3
@@ -127,9 +189,9 @@ jobs:
         uses: docker/metadata-action@v5
         with:
           images: |
-            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+            docker.io/huggingface/text-generation-inference-ci
           tags: |
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
       # If main, release or tag
       - name: Extract metadata (tags, labels) for Docker
         if: ${{ github.event_name != 'pull_request' }}
@@ -137,16 +199,16 @@ jobs:
         uses: docker/metadata-action@v4.3.0
         with:
           flavor: |
-            latest=auto
+            latest=false
           images: |
             registry.internal.huggingface.tech/api-inference/community/text-generation-inference
             ghcr.io/huggingface/text-generation-inference
             db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
           tags: |
-            type=semver,pattern={{version}}${{ env.LABEL }}
-            type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }}
-            type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
-            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+            type=semver,pattern={{version}}${{ env.LABEL_EXTENSION }}
+            type=semver,pattern={{major}}.{{minor}}${{ env.LABEL_EXTENSION }}
+            type=raw,value=latest${{ env.LABEL_EXTENSION }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
       - name: Build and push Docker image
         id: build-and-push
         uses: docker/build-push-action@v4
@@ -157,27 +219,66 @@ jobs:
           platforms: 'linux/amd64'
           build-args: |
             GIT_SHA=${{ env.GITHUB_SHA }}
-            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
             PLATFORM=${{ env.PLATFORM }}
+            build_type=${{ env.BUILD_TYPE }}
+            sccache_gha_enabled=on
+            actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
+            actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
+          target: ${{ env.TARGET }}
          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-          cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
-          cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+          cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+          cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
       - name: Final
         id: final
         run: |
-          echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+          if [ "${{ github.event_name }}" = "pull_request" ]; then
+            echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+          else
+            echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+          fi
          echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
          echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
          echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
-          echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+          echo "label_extension=${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
          echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
-  integration_tests:
+  precompile_neuron_models:
    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
+      group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    needs: build-and-push
-    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
+    if: needs.build-and-push.outputs.label_extension == '-neuron'
+    runs-on:
+      group: ${{ needs.build-and-push.outputs.runs_on }}
+    env:
+      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+      - name: Install
+        run: |
+          make install-integration-tests
+      - name: Export neuron models
+        run: |
+          export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
+          echo $DOCKER_IMAGE
+          docker pull $DOCKER_IMAGE
+          export HF_TOKEN=${{ secrets.HF_TOKEN_NEURON }}
+          python integration-tests/fixtures/neuron/export_models.py
+  integration_tests:
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs: [precompile_neuron_models, build-and-push]
+    if: ${{ always() && !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && needs.build-and-push.outputs.runs_on != 'ubuntu-latest' }}
    runs-on:
      group: ${{ needs.build-and-push.outputs.runs_on }}
    env:
@@ -204,3 +305,23 @@ jobs:
          echo $DOCKER_IMAGE
          docker pull $DOCKER_IMAGE
          pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
+
+  backend_trtllm_cxx_tests:
+    needs: build-and-push
+    if: needs.build-and-push.outputs.label_extension == '-trtllm'
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    runs-on:
+      group: aws-g6-12xl-plus-priv-cache
+    container:
+      image: ${{ needs.build-and-push.outputs.docker_image }}
+      credentials:
+        username: ${{ secrets.DOCKERHUB_USERNAME }}
+        password: ${{ secrets.DOCKERHUB_PASSWORD }}
+      options: --gpus all --shm-size=8g
+
+    steps:
+      - name: Run C++/CUDA tests
+        if: ${{ env.LABEL_EXTENSION == 'ci-runtime' }}
+        run: /usr/local/tgi/bin/tgi_trtllm_backend_tests
.github/workflows/ci_build.yaml (vendored): 5 changed lines
@@ -20,6 +20,8 @@ on:
       - "Dockerfile"
       - "Dockerfile_amd"
       - "Dockerfile_intel"
+      - "Dockerfile.neuron"
+      - "Dockerfile_gaudi"
     branches:
       - "main"
   workflow_dispatch:
@@ -37,11 +39,12 @@ jobs:
       # fail-fast is true by default
       fail-fast: false
       matrix:
-        hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"]
+        hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
     uses: ./.github/workflows/build.yaml # calls the one above ^
     permissions:
       contents: write
       packages: write
+      id-token: write
     with:
       hardware: ${{ matrix.hardware }}
       # https://github.com/actions/runner/issues/2206
.github/workflows/nix_build.yaml (vendored, new file): 53 lines
@@ -0,0 +1,53 @@
name: "Nix Build Docker image"
on:
  pull_request:
  push:
    branches:
      - 'main'
    tags:
      - 'v*'
concurrency:
  group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build_nix_image:
    runs-on:
      group: aws-highmemory-32-plus-priv
    steps:
      - uses: actions/checkout@v4
      - uses: cachix/install-nix-action@v27
        with:
          nix_path: nixpkgs=channel:nixos-unstable
      - uses: cachix/cachix-action@v14
        with:
          name: text-generation-inference
          # If you chose signing key for write access
          authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
        env:
          USER: github_runner
      - name: Build
        run: nix build .#dockerImage
      - name: Initialize Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          install: true
          buildkitd-config: /tmp/buildkitd.toml
      - name: Inject slug/short variables
        uses: rlespinasse/github-slug-action@v4.4.1
      - name: Login to internal Container Registry
        # if: github.event_name != 'pull_request'
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.REGISTRY_USERNAME }}
          password: ${{ secrets.REGISTRY_PASSWORD }}
          registry: registry.internal.huggingface.tech
      - name: Push to docker
        run: |
          if [ "${{ github.event_name }}" = "pull_request" ]; then
            export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
          else
            export TAG=${{ github.ref_name }}-nix
          fi
          export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
          nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"
.github/workflows/nix_tests.yaml (vendored): 1 changed line
@@ -7,6 +7,7 @@ on:
       - "proto/**"
       - "router/**"
       - "launcher/**"
+      - "backends/**"
       - "Cargo.lock"
       - "rust-toolchain.toml"
 concurrency:
.github/workflows/tests.yaml (vendored): 20 changed lines
@@ -8,6 +8,7 @@ on:
       - "proto/**"
       - "router/**"
       - "launcher/**"
+      - "backends/**"
       - "Cargo.lock"
       - "rust-toolchain.toml"

@@ -20,19 +21,14 @@ jobs:
     runs-on:
       group: aws-highmemory-32-plus-priv
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Set up Python
         uses: actions/setup-python@v4
         id: python
         with:
           python-version: 3.11
-      - name: Install Rust
-        uses: actions-rs/toolchain@v1
+      - uses: dtolnay/rust-toolchain@1.85.0
         with:
-          # Released on: 02 May, 2024
-          # https://releases.rs/docs/1.78.0/
-          toolchain: 1.80.0
-          override: true
           components: rustfmt, clippy
       - name: Install Protoc
         uses: arduino/setup-protoc@v1
@@ -44,10 +40,18 @@ jobs:
         run: |
           sudo apt update
           sudo apt install python3.11-dev -y
+          pip install -U pip uv
+          uv venv
+          source ./.venv/bin/activate
           make install-cpu
+      - name: Download locked kernels
+        run: |
+          source ./.venv/bin/activate
+          kernels download server
       - name: Run server tests
         run: |
-          pip install pytest
+          source ./.venv/bin/activate
+          uv pip install pytest
           export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv server/tests
       - name: Pre-commit checks
.github/workflows/trufflehog.yaml (vendored): 15 changed lines
@@ -10,9 +10,12 @@ jobs:
   trufflehog:
     runs-on: ubuntu-latest
     steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Secret Scanning
-       uses: trufflesecurity/trufflehog@main
+       uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
+       with:
+         # exclude buggy postgres detector that is causing false positives and not relevant to our codebase
+         extra_args: --results=verified,unknown --exclude-detectors=postgres
.gitignore (vendored): 6 changed lines
@@ -23,3 +23,9 @@ server/fbgemmm

 .direnv/
 .venv/
+
+# Gaudi auto-generated files
+hl-smi_log*.txt
+.graph_dumps
+out
+hqt_output
Cargo.lock (generated): 1772 changed lines (diff suppressed because it is too large)
Cargo.toml: 33 changed lines
@@ -1,26 +1,27 @@
 [workspace]
 members = [
     "benchmark",
     "backends/v2",
     "backends/v3",
     "backends/grpc-metadata",
     "backends/trtllm",
+    "backends/llamacpp",
     "launcher",
     "router"
 ]
 default-members = [
     "benchmark",
     "backends/v2",
     "backends/v3",
     "backends/grpc-metadata",
     # "backends/trtllm",
     "launcher",
     "router"
 ]
 resolver = "2"

 [workspace.package]
-version = "3.0.0"
+version = "3.2.3-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
@@ -28,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
 [workspace.dependencies]
 base64 = "0.22.0"
 tokenizers = { version = "0.20.0", features = ["http"] }
-hf-hub = { version = "0.3.1", features = ["tokio"] }
+hf-hub = { version = "0.4.2", features = ["tokio"] }
 metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 minijinja = { version = "2.2.0", features = ["json"] }
Dockerfile
152
Dockerfile
@ -1,5 +1,5 @@
|
|||||||
# Rust builder
|
# Rust builder
|
||||||
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
|
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||||
@ -45,21 +45,16 @@ RUN cargo build --profile release-opt --frozen
|
|||||||
# Python builder
|
# Python builder
|
||||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||||
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
|
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
|
||||||
|
WORKDIR /usr/src/
|
||||||
|
|
||||||
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
|
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
|
||||||
ARG PYTORCH_VERSION=2.4.0
|
ARG PYTORCH_VERSION=2.6
|
||||||
|
|
||||||
ARG PYTHON_VERSION=3.11
|
ARG PYTHON_VERSION=3.11
|
||||||
|
|
||||||
# Keep in sync with `server/pyproject.toml
|
# Keep in sync with `server/pyproject.toml
|
||||||
ARG CUDA_VERSION=12.4
|
|
||||||
ARG MAMBA_VERSION=24.3.0-0
|
|
||||||
ARG CUDA_CHANNEL=nvidia
|
|
||||||
ARG INSTALL_CHANNEL=pytorch
|
|
||||||
# Automatically set by buildx
|
# Automatically set by buildx
|
||||||
ARG TARGETPLATFORM
|
ARG TARGETPLATFORM
|
||||||
|
|
||||||
ENV PATH /opt/conda/bin:$PATH
|
|
||||||
|
|
||||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||||
build-essential \
|
build-essential \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
@ -67,26 +62,12 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||||||
curl \
|
curl \
|
||||||
git && \
|
git && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
|
||||||
# Install conda
|
ENV PATH="$PATH:/root/.local/bin"
|
||||||
# translating Docker's TARGETPLATFORM into mamba arches
|
RUN uv python install ${PYTHON_VERSION}
|
||||||
RUN case ${TARGETPLATFORM} in \
|
RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
|
||||||
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
|
ENV VIRTUAL_ENV=/usr/src/.venv/
|
||||||
*) MAMBA_ARCH=x86_64 ;; \
|
ENV PATH="$PATH:/usr/src/.venv/bin/"
|
||||||
esac && \
|
|
||||||
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
|
||||||
RUN chmod +x ~/mambaforge.sh && \
|
|
||||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
|
||||||
rm ~/mambaforge.sh
|
|
||||||
|
|
||||||
# Install pytorch
|
|
||||||
# On arm64 we exit with an error code
|
|
||||||
RUN case ${TARGETPLATFORM} in \
|
|
||||||
"linux/arm64") exit 1 ;; \
|
|
||||||
*) /opt/conda/bin/conda update -y conda && \
|
|
||||||
/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
|
|
||||||
esac && \
|
|
||||||
/opt/conda/bin/conda clean -ya
|
|
||||||
|
|
||||||
# CUDA kernels builder image
|
# CUDA kernels builder image
|
||||||
FROM pytorch-install AS kernel-builder
|
FROM pytorch-install AS kernel-builder
|
||||||
@ -106,7 +87,7 @@ WORKDIR /usr/src
|
|||||||
COPY server/Makefile-flash-att Makefile
|
COPY server/Makefile-flash-att Makefile
|
||||||
|
|
||||||
# Build specific version of flash attention
|
# Build specific version of flash attention
|
||||||
RUN make build-flash-attention
|
RUN . .venv/bin/activate && make build-flash-attention
|
||||||
|
|
||||||
# Build Flash Attention v2 CUDA kernels
|
# Build Flash Attention v2 CUDA kernels
|
||||||
FROM kernel-builder AS flash-att-v2-builder
|
FROM kernel-builder AS flash-att-v2-builder
|
||||||
@ -116,14 +97,14 @@ WORKDIR /usr/src
|
|||||||
COPY server/Makefile-flash-att-v2 Makefile
|
COPY server/Makefile-flash-att-v2 Makefile
|
||||||
|
|
||||||
# Build specific version of flash attention v2
|
# Build specific version of flash attention v2
|
||||||
RUN make build-flash-attention-v2-cuda
|
RUN . .venv/bin/activate && make build-flash-attention-v2-cuda
|
||||||
|
|
||||||
# Build Transformers exllama kernels
|
# Build Transformers exllama kernels
|
||||||
FROM kernel-builder AS exllama-kernels-builder
|
FROM kernel-builder AS exllama-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/exllama_kernels/ .
|
COPY server/exllama_kernels/ .
|
||||||
|
|
||||||
RUN python setup.py build
|
RUN . .venv/bin/activate && python setup.py build
|
||||||
|
|
||||||
# Build Transformers exllama kernels
|
# Build Transformers exllama kernels
|
||||||
FROM kernel-builder AS exllamav2-kernels-builder
|
FROM kernel-builder AS exllamav2-kernels-builder
|
||||||
@ -131,54 +112,43 @@ WORKDIR /usr/src
|
|||||||
COPY server/Makefile-exllamav2/ Makefile
|
COPY server/Makefile-exllamav2/ Makefile
|
||||||
|
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN make build-exllamav2
|
RUN . .venv/bin/activate && make build-exllamav2
|
||||||
|
|
||||||
# Build Transformers awq kernels
|
# Build Transformers awq kernels
|
||||||
FROM kernel-builder AS awq-kernels-builder
|
FROM kernel-builder AS awq-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/Makefile-awq Makefile
|
COPY server/Makefile-awq Makefile
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN make build-awq
|
RUN . .venv/bin/activate && make build-awq
|
||||||
|
|
||||||
# Build eetq kernels
|
|
||||||
FROM kernel-builder AS eetq-kernels-builder
|
|
||||||
WORKDIR /usr/src
|
|
||||||
COPY server/Makefile-eetq Makefile
|
|
||||||
# Build specific version of transformers
|
|
||||||
RUN make build-eetq
|
|
||||||
|
|
||||||
# Build Lorax Punica kernels
|
# Build Lorax Punica kernels
|
||||||
FROM kernel-builder AS lorax-punica-builder
|
FROM kernel-builder AS lorax-punica-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/Makefile-lorax-punica Makefile
|
COPY server/Makefile-lorax-punica Makefile
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
|
RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
|
||||||
|
|
||||||
# Build Transformers CUDA kernels
|
# Build Transformers CUDA kernels
|
||||||
FROM kernel-builder AS custom-kernels-builder
|
FROM kernel-builder AS custom-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/custom_kernels/ .
|
COPY server/custom_kernels/ .
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN python setup.py build
|
RUN . .venv/bin/activate && python setup.py build
|
||||||
|
|
||||||
# Build mamba kernels
|
# Build mamba kernels
|
||||||
FROM kernel-builder AS mamba-builder
|
FROM kernel-builder AS mamba-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/Makefile-selective-scan Makefile
|
COPY server/Makefile-selective-scan Makefile
|
||||||
RUN make build-all
|
RUN . .venv/bin/activate && make build-all
|
||||||
|
|
||||||
# Build flashinfer
|
# Build flashinfer
|
||||||
FROM kernel-builder AS flashinfer-builder
|
FROM kernel-builder AS flashinfer-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/Makefile-flashinfer Makefile
|
COPY server/Makefile-flashinfer Makefile
|
||||||
RUN make install-flashinfer
|
RUN . .venv/bin/activate && make install-flashinfer
|
||||||
|
|
||||||
# Text Generation Inference base image
|
# Text Generation Inference base image
|
||||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
|
FROM nvidia/cuda:12.4.0-base-ubuntu22.04 AS base
|
||||||
|
|
||||||
# Conda env
|
|
||||||
ENV PATH=/opt/conda/bin:$PATH \
|
|
||||||
CONDA_PREFIX=/opt/conda
|
|
||||||
|
|
||||||
# Text Generation Inference base env
|
# Text Generation Inference base env
|
||||||
ENV HF_HOME=/data \
|
ENV HF_HOME=/data \
|
||||||
@ -195,50 +165,61 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||||||
git \
|
git \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy conda with PyTorch installed
|
# RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
COPY --from=pytorch-install /opt/conda /opt/conda
|
# ENV PATH="$PATH:/root/.local/bin"
|
||||||
|
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
|
||||||
# Copy build artifacts from flash attention builder
|
|
||||||
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
|
|
||||||
# Copy build artifacts from flash attention v2 builder
|
|
||||||
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
|
|
||||||
|
|
||||||
# Copy build artifacts from custom kernels builder
|
|
||||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
# Copy build artifacts from exllama kernels builder
|
|
||||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
# Copy build artifacts from exllamav2 kernels builder
|
|
||||||
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
# Copy build artifacts from awq kernels builder
|
|
||||||
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
# Copy build artifacts from eetq kernels builder
|
|
||||||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
# Copy build artifacts from lorax punica kernels builder
|
|
||||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
# Copy build artifacts from mamba builder
|
|
||||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
|
|
||||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
|
|
||||||
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
|
|
||||||
|
|
||||||
# Install flash-attention dependencies
|
# Install flash-attention dependencies
|
||||||
RUN pip install einops --no-cache-dir
|
# RUN pip install einops --no-cache-dir
|
||||||
|
|
||||||
|
# Copy env with PyTorch installed
|
||||||
|
COPY --from=pytorch-install /usr/src/.venv /usr/src/.venv
|
||||||
|
ENV PYTHON_VERSION=3.11
|
||||||
|
RUN uv python install ${PYTHON_VERSION}
|
||||||
|
ENV VIRTUAL_ENV=/usr/src/.venv/
|
||||||
|
ENV PATH="$PATH:/usr/src/.venv/bin/"
|
||||||
|
|
||||||
# Install server
|
# Install server
|
||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY server server
|
COPY server server
|
||||||
COPY server/Makefile server/Makefile
|
COPY server/Makefile server/Makefile
|
||||||
|
ENV HF_KERNELS_CACHE=/kernels
|
||||||
RUN cd server && \
|
RUN cd server && \
|
||||||
make gen-server && \
|
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
|
||||||
pip install -r requirements_cuda.txt && \
|
make gen-server-raw && \
|
||||||
pip install ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
|
kernels download .
|
||||||
pip install nvidia-nccl-cu12==2.22.3
|
|
||||||
|
|
||||||
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
|
RUN cd server && \
|
||||||
|
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
|
||||||
|
uv pip install nvidia-nccl-cu12==2.25.1 && \
|
||||||
|
pwd && \
|
||||||
|
text-generation-server --help
|
||||||
|
|
||||||
|
# Copy build artifacts from flash attention builder
|
||||||
|
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
|
||||||
|
# Copy build artifacts from flash attention v2 builder
|
||||||
|
COPY --from=flash-att-v2-builder /usr/src/.venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
|
||||||
|
# Copy build artifacts from custom kernels builder
|
||||||
|
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from exllama kernels builder
|
||||||
|
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from exllamav2 kernels builder
|
||||||
|
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from awq kernels builder
|
||||||
|
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from lorax punica kernels builder
|
||||||
|
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from mamba builder
|
||||||
|
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=flashinfer-builder /usr/src/.venv/lib/python3.11/site-packages/flashinfer/ /usr/src/.venv/lib/python3.11/site-packages/flashinfer/
|
||||||
|
|
||||||
|
|
||||||
|
# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
|
||||||
# Required to find libpython within the rust binaries
|
# Required to find libpython within the rust binaries
|
||||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
|
||||||
# This is needed because exl2 tries to load flash-attn
|
# This is needed because exl2 tries to load flash-attn
|
||||||
# And fails with our builds.
|
# And fails with our builds.
|
||||||
ENV EXLLAMA_NO_FLASH_ATTN=1
|
ENV EXLLAMA_NO_FLASH_ATTN=1
|
||||||
@ -273,5 +254,6 @@ FROM base
|
|||||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||||
RUN chmod +x /tgi-entrypoint.sh
|
RUN chmod +x /tgi-entrypoint.sh
|
||||||
|
|
||||||
|
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/"
|
||||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||||
# CMD ["--json-output"]
|
# CMD ["--json-output"]
|
||||||
|
Dockerfile.neuron (new file): 167 lines
@@ -0,0 +1,167 @@
# Fetch and extract the TGI sources
FROM alpine AS tgi
RUN mkdir -p /tgi

# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
# Note: we cannot use the cargo-chef base image as it uses python 3.11
FROM ubuntu:22.04 AS chef

RUN apt-get update -y \
    && apt-get install -y --no-install-recommends \
    curl ca-certificates build-essential \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo install cargo-chef --locked

WORKDIR /usr/src

FROM chef AS planner
COPY backends/neuron/Cargo.toml Cargo.toml
COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN apt-get update -y \
    && apt-get install -y --no-install-recommends \
    unzip python3-dev libssl-dev pkg-config \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY backends/neuron/Cargo.toml Cargo.toml
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --release

# Python base image
FROM ubuntu:22.04 AS base

RUN apt-get update -y \
    && apt-get install -y --no-install-recommends \
    python3-pip \
    python3-setuptools \
    python-is-python3 \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean
RUN pip3 --no-cache-dir install --upgrade pip

# Python server build image
FROM base AS pyserver

RUN apt-get update -y \
    && apt-get install -y --no-install-recommends \
    make \
    python3-venv \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

RUN install -d /pyserver
WORKDIR /pyserver
COPY backends/neuron/server server
COPY proto proto
RUN pip3 install -r server/build-requirements.txt
RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package

# Neuron base image (used for deployment)
FROM base AS neuron

# Install system prerequisites
RUN apt-get update -y \
    && apt-get install -y --no-install-recommends \
    gnupg2 \
    wget \
    python3-dev \
    libexpat1 \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -

# Install neuronx packages
RUN apt-get update -y \
    && apt-get install -y --no-install-recommends \
    aws-neuronx-dkms=2.19.64.0 \
    aws-neuronx-collectives=2.23.135.0-3e70920f2 \
    aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
    aws-neuronx-tools=2.20.204.0 \
    libxml2 \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"

# Install manually torch CPU version to avoid pulling CUDA
RUN pip3 install \
    torch==2.5.1 \
    torchvision==0.20.1 \
    --index-url https://download.pytorch.org/whl/cpu

RUN pip3 install \
    neuronx-cc==2.16.372.0 \
    torch-neuronx==2.5.1.2.4.0 \
    transformers-neuronx==0.13.322 \
    neuronx-distributed==0.10.1 \
    libneuronxla==2.1.681.0 \
    --extra-index-url=https://pip.repos.neuron.amazonaws.com

# Install HuggingFace packages
RUN pip3 install \
    hf_transfer huggingface_hub

# Install optimum-neuron
COPY --from=optimum-neuron /optimum-neuron optimum-neuron
RUN pip3 install ./optimum-neuron

# TGI base env
ENV HUGGINGFACE_HUB_CACHE=/tmp \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Disable color logs as they are not supported by CloudWatch
ENV LOGURU_COLORIZE=NO
ENV LOG_COLORIZE=0

# Install router
COPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# Install python server
COPY --from=pyserver /pyserver/build/dist dist
RUN pip install dist/text_generation_server*.tar.gz

# Final image
FROM neuron

COPY backends/neuron/tgi_env.py /tgi_env.py
COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
Dockerfile_amd: 378 changed lines
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -41,262 +41,237 @@ COPY backends backends
 COPY launcher launcher
 RUN cargo build --profile release-opt --frozen

-# Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.2 AS base
+FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
+
+ARG HIPBLASLT_BRANCH="4d40e36"
+ARG HIPBLAS_COMMON_BRANCH="7c1566b"
+ARG LEGACY_HIPBLASLT_OPTION=
+ARG RCCL_BRANCH="648a58d"
+ARG RCCL_REPO="https://github.com/ROCm/rccl"
+ARG TRITON_BRANCH="e5be006"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+ARG PYTORCH_BRANCH="3a585126"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
+ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
+ARG FA_BRANCH="b7d29fb"
+ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+ARG AITER_BRANCH="21d47a9"
+ARG AITER_REPO="https://github.com/ROCm/aiter.git"
+
+ENV PATH=/opt/rocm/llvm/bin:$PATH
+ENV ROCM_PATH=/opt/rocm
+ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
+ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+
+ARG PYTHON_VERSION=3.11
+
+RUN mkdir -p /app
+WORKDIR /app
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install Python and other dependencies
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
     ca-certificates \
     ccache \
     curl \
     git \
-    make \
-    libmsgpack-dev \
-    libssl-dev \
-    llvm-dev \
-    g++ \
-    # Needed to build VLLM & flash.
-    rocthrust-dev \
-    hipsparse-dev \
-    hipblas-dev \
-    hipcub-dev \
-    rocblas-dev \
-    hiprand-dev \
-    hipfft-dev \
-    rocrand-dev \
-    miopen-hip-dev \
-    hipsolver-dev \
-    rccl-dev \
-    cmake \
-    python3.11-venv && \
-    rm -rf /var/lib/apt/lists/*
+    ninja-build \
+    cmake \
+    software-properties-common \
+    python3.11-dev \
+    python3.11-venv && \
+    rm -rf /var/lib/apt/lists/*

-# Keep in sync with `server/pyproject.toml
-ARG MAMBA_VERSION=23.1.0-1
-ARG PYTHON_VERSION='3.11.10'
-# Automatically set by buildx
-ARG TARGETPLATFORM
-ENV PATH=/opt/conda/bin:$PATH
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
+ENV PATH="$PATH:/root/.local/bin"
+RUN uv python install ${PYTHON_VERSION}
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"

-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython

-# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
-# Install mamba
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
-    *) MAMBA_ARCH=x86_64 ;; \
-    esac && \
-    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
-RUN chmod +x ~/mambaforge.sh && \
-    bash ~/mambaforge.sh -b -p /opt/conda && \
-    mamba init && \
-    rm ~/mambaforge.sh
-
-# RUN conda install intel::mkl-static intel::mkl-include
-# Install pytorch
-# On arm64 we exit with an error code
-RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") exit 1 ;; \
-    *) /opt/conda/bin/conda update -y conda && \
-    /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
-    esac && \
-    /opt/conda/bin/conda clean -ya
-
-# Install flash-attention, torch dependencies
-RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
-
-RUN conda install mkl=2021
-ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
-
-ARG COMMON_WORKDIR=/
-WORKDIR ${COMMON_WORKDIR}
-
-# Install HIPBLASLt
 FROM base AS build_hipblaslt
-ARG HIPBLASLT_BRANCH="e6da924"
-RUN git clone https://github.com/ROCm/hipBLASLt.git \
-    && cd hipBLASLt \
+ARG HIPBLASLT_BRANCH
+ARG HIPBLAS_COMMON_BRANCH
+# Set to "--legacy_hipblas_direct" for ROCm<=6.2
+ARG LEGACY_HIPBLASLT_OPTION
+RUN git clone https://github.com/ROCm/hipBLAS-common.git
+RUN . .venv/bin/activate && cd hipBLAS-common \
|
||||||
|
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
|
||||||
|
&& mkdir build \
|
||||||
|
&& cd build \
|
||||||
|
&& cmake .. \
|
||||||
|
&& make package \
|
||||||
|
&& dpkg -i ./*.deb
|
||||||
|
RUN git clone https://github.com/ROCm/hipBLASLt
|
||||||
|
RUN . .venv/bin/activate && cd hipBLASLt \
|
||||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||||
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
|
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
|
||||||
&& cd build/release \
|
&& cd build/release \
|
||||||
&& make package
|
&& make package
|
||||||
|
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
|
||||||
|
|
||||||
FROM scratch AS export_hipblaslt
|
|
||||||
ARG COMMON_WORKDIR
|
|
||||||
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
|
|
||||||
|
|
||||||
# RCCL build stages
|
|
||||||
FROM base AS build_rccl
|
FROM base AS build_rccl
|
||||||
ARG RCCL_BRANCH="rocm-6.2.0"
|
ARG RCCL_BRANCH
|
||||||
RUN git clone https://github.com/ROCm/rccl \
|
ARG RCCL_REPO
|
||||||
&& cd rccl \
|
RUN git clone ${RCCL_REPO}
|
||||||
|
RUN . .venv/bin/activate && cd rccl \
|
||||||
&& git checkout ${RCCL_BRANCH} \
|
&& git checkout ${RCCL_BRANCH} \
|
||||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||||
FROM scratch AS export_rccl
|
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
|
||||||
ARG COMMON_WORKDIR
|
|
||||||
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
|
|
||||||
|
|
||||||
# Triton build stages
|
|
||||||
FROM base AS build_triton
|
FROM base AS build_triton
|
||||||
ARG TRITON_BRANCH="e192dba"
|
ARG TRITON_BRANCH
|
||||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
ARG TRITON_REPO
|
||||||
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
|
RUN git clone ${TRITON_REPO}
|
||||||
&& cd triton \
|
RUN . .venv/bin/activate && cd triton \
|
||||||
&& git checkout ${TRITON_BRANCH} \
|
&& git checkout ${TRITON_BRANCH} \
|
||||||
&& cd python \
|
&& cd python \
|
||||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
FROM scratch AS export_triton
|
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
|
||||||
ARG COMMON_WORKDIR
|
|
||||||
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
|
|
||||||
|
|
||||||
# # AMD-SMI build stages
|
|
||||||
FROM base AS build_amdsmi
|
FROM base AS build_amdsmi
|
||||||
RUN cd /opt/rocm/share/amd_smi \
|
RUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \
|
||||||
&& pip wheel . --wheel-dir=dist
|
&& pip wheel . --wheel-dir=dist
|
||||||
FROM scratch AS export_amdsmi
|
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
|
||||||
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
|
|
||||||
|
|
||||||
|
FROM base AS build_pytorch
|
||||||
|
ARG PYTORCH_BRANCH
|
||||||
|
ARG PYTORCH_VISION_BRANCH
|
||||||
|
ARG PYTORCH_REPO
|
||||||
|
ARG PYTORCH_VISION_REPO
|
||||||
|
ARG FA_BRANCH
|
||||||
|
ARG FA_REPO
|
||||||
|
RUN git clone ${PYTORCH_REPO} pytorch
|
||||||
|
RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
|
||||||
|
pip install -r requirements.txt && git submodule update --init --recursive \
|
||||||
|
&& python3 tools/amd_build/build_amd.py \
|
||||||
|
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
|
||||||
|
&& pip install dist/*.whl
|
||||||
|
RUN git clone ${PYTORCH_VISION_REPO} vision
|
||||||
|
RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
|
||||||
|
&& python3 setup.py bdist_wheel --dist-dir=dist \
|
||||||
|
&& pip install dist/*.whl
|
||||||
|
RUN git clone ${FA_REPO}
|
||||||
|
RUN . .venv/bin/activate && cd flash-attention \
|
||||||
|
&& git checkout ${FA_BRANCH} \
|
||||||
|
&& git submodule update --init \
|
||||||
|
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
|
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
|
||||||
|
&& cp /app/vision/dist/*.whl /app/install \
|
||||||
|
&& cp /app/flash-attention/dist/*.whl /app/install
|
||||||
|
|
||||||
FROM base as build_pytorch
|
FROM base AS final
|
||||||
|
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||||
|
dpkg -i /install/*deb \
|
||||||
|
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||||
|
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
|
||||||
|
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
|
||||||
|
dpkg -i /install/*deb \
|
||||||
|
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||||
|
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
|
||||||
|
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||||
|
. .venv/bin/activate && \
|
||||||
|
pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||||
|
. .venv/bin/activate && \
|
||||||
|
pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||||
|
. .venv/bin/activate && \
|
||||||
|
pip install /install/*.whl
|
||||||
|
|
||||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
ARG AITER_REPO
|
||||||
if ls /install/*.deb; then \
|
ARG AITER_BRANCH
|
||||||
dpkg -i /install/*.deb \
|
RUN git clone --recursive ${AITER_REPO}
|
||||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
RUN . .venv/bin/activate && cd aiter \
|
||||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
&& git checkout ${AITER_BRANCH} \
|
||||||
fi
|
&& git submodule update --init --recursive \
|
||||||
|
&& pip install -r requirements.txt \
|
||||||
|
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
|
||||||
|
|
||||||
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
|
RUN rm -rf /var/lib/apt/lists/*
|
||||||
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
|
||||||
|
|
||||||
# A commit to fix the output scaling factor issue in _scaled_mm
|
|
||||||
# Not yet in 2.5.0-rc1
|
|
||||||
ARG PYTORCH_BRANCH="cedc116"
|
|
||||||
ARG PYTORCH_VISION_BRANCH="v0.19.1"
|
|
||||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
|
||||||
|
|
||||||
RUN git clone ${PYTORCH_REPO} pytorch \
|
|
||||||
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
|
|
||||||
&& pip install -r requirements.txt --no-cache-dir \
|
|
||||||
&& python tools/amd_build/build_amd.py \
|
|
||||||
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
|
|
||||||
FROM scratch as export_pytorch
|
|
||||||
ARG COMMON_WORKDIR
|
|
||||||
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
|
|
||||||
|
|
||||||
FROM base AS install_deps
|
|
||||||
|
|
||||||
ARG COMMON_WORKDIR
|
|
||||||
|
|
||||||
# Install hipblaslt
|
|
||||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
|
||||||
if ls /install/*.deb; then \
|
|
||||||
dpkg -i /install/*.deb \
|
|
||||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
|
||||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
|
|
||||||
if ls /install/*.deb; then \
|
|
||||||
dpkg -i /install/*.deb \
|
|
||||||
# RCCL needs to be installed twice
|
|
||||||
&& dpkg -i /install/*.deb \
|
|
||||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
|
||||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
|
|
||||||
if ls /install/*.whl; then \
|
|
||||||
# Preemptively uninstall to prevent pip same-version no-installs
|
|
||||||
pip uninstall -y triton \
|
|
||||||
&& pip install /install/*.whl; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
|
|
||||||
# Preemptively uninstall to prevent pip same-version no-installs
|
|
||||||
pip uninstall -y amdsmi \
|
|
||||||
&& pip install /install/*.whl;
|
|
||||||
|
|
||||||
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
|
|
||||||
if ls /install/*.whl; then \
|
|
||||||
# Preemptively uninstall to prevent pip same-version no-installs
|
|
||||||
pip uninstall -y torch torchvision \
|
|
||||||
&& pip install /install/*.whl; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
FROM install_deps AS kernel-builder
|
|
||||||
|
|
||||||
|
FROM final AS kernel-builder
|
||||||
# # Build vllm kernels
|
# # Build vllm kernels
|
||||||
FROM kernel-builder AS vllm-builder
|
FROM kernel-builder AS vllm-builder
|
||||||
WORKDIR /usr/src
|
|
||||||
|
|
||||||
COPY server/Makefile-vllm Makefile
|
COPY server/Makefile-vllm Makefile
|
||||||
|
RUN . .venv/bin/activate && pip install setuptools_scm
|
||||||
|
|
||||||
# Build specific version of vllm
|
# Build specific version of vllm
|
||||||
RUN make build-vllm-rocm
|
RUN . .venv/bin/activate && make build-vllm-rocm
|
||||||
|
|
||||||
# Build Flash Attention v2 kernels
|
|
||||||
FROM kernel-builder AS flash-att-v2-builder
|
|
||||||
WORKDIR /usr/src
|
|
||||||
|
|
||||||
COPY server/Makefile-flash-att-v2 Makefile
|
|
||||||
|
|
||||||
# Build specific version of flash attention v2
|
|
||||||
RUN make build-flash-attention-v2-rocm
|
|
||||||
|
|
||||||
# Build Transformers CUDA kernels (gpt-neox and bloom)
|
# Build Transformers CUDA kernels (gpt-neox and bloom)
|
||||||
FROM kernel-builder AS custom-kernels-builder
|
FROM kernel-builder AS custom-kernels-builder
|
||||||
WORKDIR /usr/src
|
|
||||||
COPY server/custom_kernels/ .
|
COPY server/custom_kernels/ .
|
||||||
RUN python setup.py build
|
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
|
|
||||||
# Build exllama kernels
|
# Build exllama kernels
|
||||||
FROM kernel-builder AS exllama-kernels-builder
|
FROM kernel-builder AS exllama-kernels-builder
|
||||||
WORKDIR /usr/src
|
|
||||||
COPY server/exllama_kernels/ .
|
COPY server/exllama_kernels/ .
|
||||||
|
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
RUN python setup.py build
|
|
||||||
|
|
||||||
# Build exllama v2 kernels
|
# Build exllama v2 kernels
|
||||||
FROM kernel-builder AS exllamav2-kernels-builder
|
FROM kernel-builder AS exllamav2-kernels-builder
|
||||||
WORKDIR /usr/src
|
|
||||||
COPY server/exllamav2_kernels/ .
|
COPY server/exllamav2_kernels/ .
|
||||||
|
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
|
|
||||||
RUN python setup.py build
|
FROM kernel-builder AS marlin-kernels
|
||||||
|
ENV MARLIN_KERNELS_BRANCH=v0.3.6
|
||||||
|
ENV VLLM_TARGET_DEVICE=rocm
|
||||||
|
RUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \
|
||||||
|
cd marlin-kernels && \
|
||||||
|
git checkout ${MARLIN_KERNELS_BRANCH} && \
|
||||||
|
python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
|
|
||||||
FROM install_deps AS base-copy
|
FROM kernel-builder AS moe-kernels
|
||||||
|
ENV MOE_KERNELS_BRANCH=v0.8.2
|
||||||
|
ENV VLLM_TARGET_DEVICE=rocm
|
||||||
|
RUN . .venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \
|
||||||
|
cd moe-kernels && \
|
||||||
|
git checkout ${MOE_KERNELS_BRANCH} && \
|
||||||
|
python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
|
|
||||||
|
FROM final AS base-copy
|
||||||
|
|
||||||
# Text Generation Inference base env
|
# Text Generation Inference base env
|
||||||
ENV HF_HOME=/data \
|
ENV HF_HOME=/data \
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||||
PORT=80
|
PORT=80
|
||||||
|
|
||||||
# Copy builds artifacts from vllm builder
|
ENV VIRTUAL_ENV=/app/.venv/
|
||||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
ENV PATH="$PATH:/app/.venv/bin/"
|
||||||
|
|
||||||
# Copy build artifacts from flash attention v2 builder
|
|
||||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
|
|
||||||
# Copy build artifacts from custom kernels builder
|
|
||||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
|
|
||||||
# Copy build artifacts from exllama kernels builder
|
|
||||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
|
|
||||||
# Copy build artifacts from exllamav2 kernels builder
|
|
||||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
|
||||||
|
|
||||||
# Install server
|
# Install server
|
||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY server server
|
COPY server server
|
||||||
COPY server/Makefile server/Makefile
|
COPY server/Makefile server/Makefile
|
||||||
RUN cd server && \
|
RUN cd server && \
|
||||||
make gen-server && \
|
uv pip install grpcio-tools mypy-protobuf && \
|
||||||
pip install -r requirements_rocm.txt && \
|
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
|
||||||
pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
|
make gen-server-raw
|
||||||
|
RUN cd server && \
|
||||||
|
pwd && \
|
||||||
|
text-generation-server --help
|
||||||
|
|
||||||
|
RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
|
||||||
|
uv pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
|
||||||
|
uv pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
|
||||||
|
uv pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
|
||||||
|
uv pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
|
||||||
|
uv pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
|
||||||
|
uv pip install /install/*.whl
|
||||||
|
RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
|
||||||
|
uv pip install /install/*.whl
|
||||||
|
|
||||||
# Install benchmarker
|
# Install benchmarker
|
||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||||
@ -304,7 +279,6 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l
|
|||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||||
# Install launcher
|
# Install launcher
|
||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
|
||||||
|
|
||||||
# AWS Sagemaker compatible image
|
# AWS Sagemaker compatible image
|
||||||
FROM base AS sagemaker
|
FROM base AS sagemaker
|
||||||
@ -335,4 +309,6 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
|||||||
RUN chmod +x /tgi-entrypoint.sh
|
RUN chmod +x /tgi-entrypoint.sh
|
||||||
|
|
||||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||||
CMD ["--json-output"]
|
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
|
||||||
|
ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
|
||||||
|
# CMD ["--json-output"]
|
||||||
|
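A hedged sketch of building and running this reworked ROCm image locally; the image tag is illustrative, while the `docker run` device flags follow the AMD instructions quoted from the README later in this diff:

```shell
# Build the ROCm image from the repository root (tag is illustrative)
docker build -t tgi-rocm -f Dockerfile_amd .
# Run on AMD Instinct GPUs, exposing /dev/kfd and /dev/dri as in the README
docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 \
    -v $PWD/data:/data tgi-rocm --model-id $model
```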
Dockerfile_gaudi (new file, 126 lines)
@@ -0,0 +1,126 @@
# Those arguments are required to build the image
ARG HABANA_VERSION=1.20.0
ARG PYTORCH_VERSION=2.6.0

# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ENV PYO3_PYTHON="/root/.local/bin/python" \
    PYTHON_SYS_EXECUTABLE="/root/.local/bin/python" \
    PYO3_PYTHON_VERSION="3.10"

RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
    && . $HOME/.local/bin/env \
    && uv python install 3.10 --default --preview \
    && test -f /root/.local/bin/python || (echo "Python 3.10 not found at /root/.local/bin/python" && exit 1)

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt

# Text Generation Inference base image
ARG HABANA_VERSION
ARG PYTORCH_VERSION

FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base

ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0

# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10
RUN python3.10 --version || (echo "Python 3.10 is not installed" && exit 1)

# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    libssl-dev \
    ca-certificates \
    make \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install server
COPY proto proto
COPY backends/gaudi/server server
COPY backends/gaudi/server/Makefile server/Makefile
ARG HABANA_VERSION
RUN cd server && \
    make gen-server && \
    pip install --no-deps -r requirements.txt && \
    bash ./dill-0.3.8-patch.sh && \
    pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base

ENV HF_HUB_ENABLE_HF_TRANSFER 1
ENV HABANA_VISIBLE_DEVICES all
ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE

COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
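The new backends/gaudi/Makefile further down in this diff wires the build up; an equivalent direct invocation, using the defaults declared at the top of the Dockerfile, looks like this (the `tgi-gaudi` tag comes from that Makefile):

```shell
# Build the Gaudi image with the pinned SynapseAI and PyTorch versions
docker build -t tgi-gaudi -f Dockerfile_gaudi . \
    --build-arg HABANA_VERSION=1.20.0 --build-arg PYTORCH_VERSION=2.6.0
```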
@@ -1,6 +1,6 @@
 ARG PLATFORM=xpu

-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen

 # Text Generation Inference base image for Intel

-FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS xpu
+FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu

 USER root

@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/

 RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d

-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200

 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -96,29 +96,28 @@ ENV HF_HOME=/data \

-WORKDIR /usr/src
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
-
-RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
+WORKDIR /usr/src
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu

 # Install server
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_intel.txt && \
-    pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+    pip install -U pip uv && \
+    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir

-ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
 ENV CCL_ZE_IPC_EXCHANGE=sockets
-#ENV TORCH_LLM_ALLREDUCE=1
-#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_LLM_ALLREDUCE=1
+ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
+
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
@@ -158,7 +157,7 @@ ARG MAMBA_VERSION=23.1.0-1
 ARG PYTHON_VERSION='3.11.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH

 # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
 # Install mamba
@@ -181,22 +180,14 @@ RUN case ${TARGETPLATFORM} in \

 RUN conda install -c conda-forge gperftools mkl

-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-
-RUN pip install triton py-libnuma
+RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
+RUN pip install triton==3.1.0 py-libnuma

 WORKDIR /usr/src

-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout b7b552baf64283b594665b8687430fe92990e497
-RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
-
-RUN sed -i 's/VERSION_MINOR 6/VERSION_MINOR 5/' intel-extension-for-pytorch/version.txt
-RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
-RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl

 ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
 ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
@@ -209,10 +200,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_intel.txt && \
-    pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+    pip install -U pip uv && \
+    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -222,9 +214,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
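For reference, a hedged sketch of selecting the final stage of this Intel Dockerfile via the `PLATFORM` build argument; the image tags are illustrative, and the `cpu` stage name is an assumption not shown in this excerpt (only the `xpu` default appears above):

```shell
# XPU variant (the declared default, PLATFORM=xpu)
docker build -t tgi-xpu -f Dockerfile_intel --build-arg PLATFORM=xpu .
# CPU variant (stage name assumed to be "cpu")
docker build -t tgi-cpu -f Dockerfile_intel --build-arg PLATFORM=cpu .
```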
Dockerfile_llamacpp (new file, 88 lines)
@@ -0,0 +1,88 @@
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps

ARG llamacpp_version=b4827
ARG llamacpp_cuda=OFF
ARG llamacpp_native=ON
ARG llamacpp_cpu_arm_arch=native
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real

WORKDIR /opt/src

ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
    clang \
    cmake \
    curl \
    git \
    python3-dev \
    libssl-dev \
    pkg-config \
    tar

ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN mkdir -p llama.cpp \
    && tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
    && cd llama.cpp \
    && cmake -B build \
        -DCMAKE_INSTALL_PREFIX=/usr \
        -DCMAKE_INSTALL_LIBDIR=/usr/lib \
        -DCMAKE_C_COMPILER=clang \
        -DCMAKE_CXX_COMPILER=clang++ \
        -DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
        -DGGML_CUDA=${llamacpp_cuda} \
        -DGGML_NATIVE=${llamacpp_native} \
        -DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
        -DLLAMA_BUILD_COMMON=OFF \
        -DLLAMA_BUILD_TESTS=OFF \
        -DLLAMA_BUILD_EXAMPLES=OFF \
        -DLLAMA_BUILD_SERVER=OFF \
    && cmake --build build --parallel --config Release \
    && cmake --install build

WORKDIR /app
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef --locked

FROM deps AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json

FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook \
    --recipe-path recipe.json \
    --profile release \
    --package text-generation-router-llamacpp
COPY . .
RUN cargo build \
    --profile release \
    --package text-generation-router-llamacpp --frozen

FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
WORKDIR /app

ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
    python3-venv \
    python3-pip

RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"

COPY backends/llamacpp/requirements.txt requirements.txt
COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/

RUN pip3 install --no-cache-dir \
    -r requirements.txt \
    -e gguf-py

COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/

ENV HF_HUB_ENABLE_HF_TRANSFER=1

ENTRYPOINT ["text-generation-router-llamacpp"]
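A build sketch using the arguments declared at the top of this new Dockerfile; the tags are illustrative:

```shell
# CPU-only build (the default: GGML_CUDA stays OFF, native CPU tuning on)
docker build -t tgi-llamacpp -f Dockerfile_llamacpp .
# CUDA build, overriding the llamacpp_cuda argument
docker build -t tgi-llamacpp-cuda -f Dockerfile_llamacpp \
    --build-arg llamacpp_cuda=ON .
```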
@@ -1,52 +1,55 @@
-ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
-ARG OMPI_VERSION="4.1.6"
-
-# Build dependencies resolver stage
-FROM lukemathwalker/cargo-chef:latest AS chef
-WORKDIR /usr/src/text-generation-inference/backends/trtllm
-
-FROM chef AS planner
-COPY . .
-RUN cargo chef prepare --recipe-path recipe.json
+ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real;100-real;120-real"
+ARG cuda_base=12.8.0
+ARG build_type=release
+ARG ompi_version=4.1.7
+ARG sccache_gha_enabled=off
+ARG actions_results_url=""
+ARG actions_runtime_token=""

 # CUDA dependent dependencies resolver stage
-FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
+FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder

-RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
-    --mount=type=cache,target=/var/lib/apt,sharing=locked \
-    apt update && apt install -y \
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
     build-essential \
     cmake \
     curl \
-    gcc \
-    g++ \
+    gcc-14 \
+    g++-14 \
     git \
     git-lfs \
+    lld \
     libssl-dev \
+    libucx-dev \
+    libasan8 \
+    libubsan1 \
     ninja-build \
     pkg-config \
+    pipx \
     python3 \
     python3-dev \
     python3-setuptools \
     tar \
-    wget
+    wget --no-install-recommends && \
+    pipx ensurepath

 ENV TGI_INSTALL_PREFIX=/usr/local/tgi
 ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt

 # Install OpenMPI
 FROM cuda-builder AS mpi-builder
-ARG OMPI_VERSION
-
-ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
-RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
-    mkdir /usr/src/mpi && \
-    tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
-    cd /usr/src/mpi && \
+WORKDIR /opt/src/mpi
+
+ARG ompi_version
+ENV OMPI_VERSION=${ompi_version}
+ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
+ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
+    https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
+
+RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
     ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
     make -j all && \
     make install && \
-    rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
+    rm -rf ${OMPI_TARBALL_FILENAME}/..

 # Install TensorRT
 FROM cuda-builder AS trt-builder
@@ -58,38 +61,62 @@ RUN chmod +x /opt/install_tensorrt.sh && \
 FROM cuda-builder AS tgi-builder
 WORKDIR /usr/src/text-generation-inference

+# Scoped global args reuse
+ARG cuda_arch_list
+ARG build_type
+ARG sccache_gha_enabled
+ARG actions_results_url
+ARG actions_runtime_token
+
 # Install Rust
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
-    chmod -R a+w /root/.rustup && \
-    chmod -R a+w /root/.cargo
-
 ENV PATH="/root/.cargo/bin:$PATH"
-RUN cargo install cargo-chef
-
-# Cache dependencies
-COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
-RUN cargo chef cook --release --recipe-path recipe.json
-
-# Build actual TGI
-ARG CUDA_ARCH_LIST
-ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
+    chmod -R a+w /root/.rustup && \
+    chmod -R a+w /root/.cargo && \
+    cargo install sccache --version ">=0.10.0" --locked
+
 ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
-ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
+ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
+ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"

-COPY . .
+ENV USE_LLD_LINKER=ON
+ENV CUDA_ARCH_LIST=${cuda_arch_list}
+
+# SCCACHE Specifics args - before finding a better, more generic, way...
+ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
+ENV ACTIONS_RESULTS_URL=${actions_results_url}
+ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
+
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY router router
+COPY backends backends
+COPY benchmark benchmark
+COPY launcher launcher
 COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
 COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
-RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
-    cd backends/trtllm && \
-    CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release

-FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
-RUN apt update && apt install -y python3-minimal python3-dev python3-pip && \
+ENV RUSTC_WRAPPER=sccache
+ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
+RUN export CC=gcc-14 \
+    export CXX=g++-14 \
+    export CMAKE_C_COMPILER_LAUNCHER=sccache && \
+    export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
+    export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
+    mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
+    cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
+    sccache --show-stats
+
+FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime
+RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
     rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
-    python3 -m pip install transformers tokenizers
+    pipx ensurepath && \
+    pipx install --include-deps transformers tokenizers

 WORKDIR /usr/local/tgi/bin

+ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
 ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
 ENV TOKENIZERS_PARALLELISM=false
 ENV OMPI_MCA_plm_rsh_agent=""
@@ -99,10 +126,33 @@ COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
 COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
 COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher

+# This is used only for the CI/CD
+FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime
+RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
+    rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
+    pipx ensurepath && \
+    pipx install --include-deps transformers tokenizers
+
+WORKDIR /usr/local/tgi/bin
+
+ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
+ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+ENV TOKENIZERS_PARALLELISM=false
+ENV OMPI_MCA_plm_rsh_agent=""
+
+COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
+COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
+COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
+
+# Basically we copy from target/debug instead of target/release
+COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
+
+# This is the final image
 FROM runtime

 LABEL co.huggingface.vendor="Hugging Face Inc."
 LABEL org.opencontainers.image.authors="hardware@hf.co"
+LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"

 ENTRYPOINT ["./text-generation-launcher"]
 CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
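A hedged sketch of a local build of this TensorRT-LLM backend image with the new build arguments; the Dockerfile name and tag are assumptions (the file name is not shown in this extract), and the architecture list is narrowed purely for illustration:

```shell
docker build -t tgi-trtllm -f Dockerfile_trtllm \
    --build-arg cuda_arch_list="90-real" \
    --build-arg build_type=release .
```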
Makefile (3 lines changed)
@@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:

 clean:
 	rm -rf target aml
+
+preview_doc:
+	doc-builder preview text-generation-inference docs/source --not_python_module
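The new target assumes the `doc-builder` CLI is available on the PATH; a hedged usage sketch (the pip package name is an assumption):

```shell
# doc-builder is Hugging Face's documentation tool, published as hf-doc-builder
pip install hf-doc-builder
make preview_doc
```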
34
README.md
34
README.md
@ -1,7 +1,7 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
<a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">
|
<a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">
|
||||||
<img width=560 width=315 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png">
|
<img width=560 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png">
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
# Text Generation Inference
|
# Text Generation Inference
|
||||||
@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
|
|||||||
volume=$PWD/data
|
volume=$PWD/data
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||||
3.0.0 ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
|
ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
And then you can make requests like
|
And then you can make requests like
|
||||||
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
|
|||||||
|
|
||||||
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
||||||
|
|
||||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0-rocm --model-id $model` instead of the command above.
|
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.3-rocm --model-id $model` instead of the command above.
|
||||||
|
|
||||||
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
|
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
|
||||||
```
|
```
|
||||||
@ -141,8 +141,8 @@ You have the option to utilize the `HF_TOKEN` environment variable for configuri
|
|||||||
For example, if you want to serve the gated Llama V2 model variants:
|
For example, if you want to serve the gated Llama V2 model variants:
|
||||||
|
|
||||||
1. Go to https://huggingface.co/settings/tokens
|
1. Go to https://huggingface.co/settings/tokens
|
||||||
2. Copy your cli READ token
|
2. Copy your CLI READ token
|
||||||
3. Export `HF_TOKEN=<your cli READ token>`
|
3. Export `HF_TOKEN=<your CLI READ token>`
|
||||||
|
|
||||||
or with Docker:
|
or with Docker:
|
||||||
|
|
||||||
@ -151,13 +151,14 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
|||||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
token=<your cli READ token>
|
token=<your cli READ token>
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
|
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
### A note on Shared Memory (shm)
|
### A note on Shared Memory (shm)
|
||||||
|
|
||||||
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
|
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
|
||||||
`PyTorch` to do distributed training/inference. `text-generation-inference` make
|
`PyTorch` to do distributed training/inference. `text-generation-inference` makes
|
||||||
use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
|
use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
|
||||||
|
|
||||||
In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
|
In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
|
||||||
@ -196,14 +197,26 @@ Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with T
|
|||||||
|
|
||||||
You can also opt to install `text-generation-inference` locally.
|
You can also opt to install `text-generation-inference` locally.
|
||||||
|
|
||||||
First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
|
First clone the repository and change directory into it:
|
||||||
Python 3.9, e.g. using `conda`:
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/huggingface/text-generation-inference
|
||||||
|
cd text-generation-inference
|
||||||
|
```
|
||||||
|
|
||||||
|
Then [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
|
||||||
|
Python 3.9, e.g. using `conda` or `python venv`:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||||
|
|
||||||
|
#using conda
|
||||||
conda create -n text-generation-inference python=3.11
|
conda create -n text-generation-inference python=3.11
|
||||||
conda activate text-generation-inference
|
conda activate text-generation-inference
|
||||||
|
|
||||||
|
#using python venv
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
```
|
```
|
||||||
|
|
||||||
You may also need to install Protoc.
|
You may also need to install Protoc.
|
||||||
@@ -250,7 +263,8 @@ locally, which can take hours.
 After that you can run TGI with `nix run`:

 ```shell
-nix run . -- --model-id meta-llama/Llama-3.1-8B-Instruct
+cd text-generation-inference
+nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
 ```

 **Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)
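If you use the flake regularly, you can avoid passing the `--extra-experimental-features` flags on every invocation by enabling those features in your Nix configuration. This is standard Nix configuration, not something specific to TGI:

```shell
# Enable the nix command and flakes permanently for your user
mkdir -p ~/.config/nix
echo "experimental-features = nix-command flakes" >> ~/.config/nix/nix.conf
```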
Binary image file changed (not shown): 201 KiB before, 209 KiB after.
backends/gaudi/Makefile (new file, 62 lines)
@@ -0,0 +1,62 @@
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := ${mkfile_dir}/../..

HABANA_VERSION := 1.20.0
PYTORCH_VERSION := 2.6.0

.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install

image:
	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)

run-local-dev-container:
	docker run -it \
	  --runtime=habana \
	  --ipc=host \
	  --cap-add=sys_nice \
	  --net=host \
	  -e HABANA_VISIBLE_DEVICES=all \
	  -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
	  -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
	  -e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
	  -e LOG_LEVEL=debug \
	  -e PORT=8080 \
	  -v /home/ubuntu/.cache/huggingface:/data \
	  -v $(PWD):/text-generation-inference \
	  -w /text-generation-inference \
	  vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest

install-dependencies:
	pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
	pip install outlines~=0.0.34
	curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

install-server:
	make -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3

install-router:
	make -C ${root_dir} install-router

install-launcher:
	make -C ${root_dir} install-launcher

# use source to load the rust in path
local-dev-install: install-dependencies
	bash -c 'source "$$HOME/.cargo/env" && \
	make install-server && \
	make install-router && \
	make install-launcher'

# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
run-integration-tests:
	uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests

# This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests
capture-expected-outputs-for-integration-tests:
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py
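A usage note for the Makefile above: because `HABANA_VERSION` and `PYTORCH_VERSION` are plain `make` variables, they can be overridden per invocation without editing the file (standard `make` behavior; the values shown are simply the current defaults):

```shell
# Build the Gaudi image against an explicit Habana/PyTorch combination
make -C backends/gaudi image HABANA_VERSION=1.20.0 PYTORCH_VERSION=2.6.0
```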
backends/gaudi/README.md (new file, 142 lines)
@@ -0,0 +1,142 @@

# Text-generation-inference - Gaudi backend

## Description

This is the TGI backend for Intel Gaudi. It is composed of the TGI server optimized for Gaudi hardware.

## Build your own image

The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:

Option 1: From the project root directory:
```bash
make -C backends/gaudi image
```

Option 2: From the Gaudi backend directory:
```bash
cd backends/gaudi
make image
```

You can now run the server with the following command:

Option 1: Sharded:
```bash
model=meta-llama/Llama-3.1-8B-Instruct
hf_token=$(cat ${HOME}/.cache/huggingface/token)
volume=${HOME}/.cache/huggingface

docker run --runtime=habana --ipc=host --cap-add=sys_nice \
  -p 8080:80 -v $volume:/data \
  -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
  tgi-gaudi --model-id $model \
  --sharded true --num-shard 8 \
  --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048
```

Option 2: Non-sharded:
```bash
model=meta-llama/Llama-3.1-8B-Instruct
hf_token=$(cat ${HOME}/.cache/huggingface/token)
volume=${HOME}/.cache/huggingface

docker run --runtime=habana --ipc=host --cap-add=sys_nice \
  -p 8080:80 -v $volume:/data \
  -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
  tgi-gaudi --model-id $model \
  --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```

## Contributing

### Local Development

This is useful if you want to run the server locally for better debugging.
```bash
make -C backends/gaudi run-local-dev-container
```

Then run the following command inside the container to install TGI for Gaudi:
```bash
make -C backends/gaudi local-dev-install
```

Add Rust to your PATH:
```bash
. "$HOME/.cargo/env"
```

Option 1: Run the server (sharded model):
```bash
LOG_LEVEL=debug text-generation-launcher \
    --model-id meta-llama/Llama-3.1-8B-Instruct \
    --sharded true \
    --num-shard 8 \
    --max-input-tokens 512 \
    --max-total-tokens 1024 \
    --max-batch-size 8 \
    --max-batch-prefill-tokens 2048
```

Option 2: Run the server (non-sharded model):
```bash
LOG_LEVEL=debug text-generation-launcher \
    --model-id meta-llama/Llama-3.1-8B-Instruct \
    --max-input-tokens 512 \
    --max-total-tokens 1024 \
    --max-batch-size 4 \
    --max-batch-prefill-tokens 2048
```

You can then test the server with the following curl command from another terminal (it can be run outside the container):

```bash
curl 127.0.0.1:8080/generate \
  -X POST \
  -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
  -H 'Content-Type: application/json'
```
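If you prefer Python to curl, the same request can be sent with the `requests` library (a minimal sketch, assuming the server above is listening on port 8080 and `requests` is installed):

```python
# Minimal Python equivalent of the curl call above.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```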
### Integration tests

To run the integration tests, you need to first build the image:
```bash
make -C backends/gaudi image
```

Then run the following command to run the integration tests:
```bash
make -C backends/gaudi run-integration-tests
```

To capture the expected outputs for the integration tests, you can run the following command:
```bash
make -C backends/gaudi capture-expected-outputs-for-integration-tests
```

#### How the integration tests work

The integration tests work as follows:

1. Start a TGI server in a container, similar to the command:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
  -p 8080:80 -v $volume:/data \
  -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
  tgi-gaudi --model-id $model \
  --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```

2. Send a `/generate` request to the server, similar to the command:
```bash
curl 127.0.0.1:8080/generate \
  -X POST \
  -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
  -H 'Content-Type: application/json'
```

3. Check the output of the server against the expected output:
```python
assert curl_output == expected_output
```

This is then repeated for a set of models and configurations.
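In practice an exact string comparison can be brittle, which is why the integration-test requirements further down in this diff pull in `Levenshtein`. A hedged sketch of a tolerance-based check (the helper name and threshold are illustrative, not part of the test suite):

```python
from Levenshtein import distance as levenshtein_distance

def outputs_match(actual: str, expected: str, max_distance: int = 5) -> bool:
    """Accept generated text within a small edit distance of the expected output."""
    return levenshtein_distance(actual, expected) <= max_distance

# Example: a one-character difference still passes the check.
assert outputs_match(
    "Deep learning is a subset of machine learning",
    "Deep learning is a subset of machine-learning",
)
```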
backends/gaudi/examples/docker_commands/docker_commands.md (new file, 283 lines)
@@ -0,0 +1,283 @@
|
|||||||
|
# Examples of Docker Commands for Gaudi Backend
|
||||||
|
|
||||||
|
This page gives a list of examples of docker run commands for some of the most popular models.
|
||||||
|
|
||||||
|
> **Note:** The parameters are chosen to maximize performance on Gaudi2 hardware; please adjust them for your own hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
|
||||||
|
|
||||||
|
## Default Precision (BF16)
|
||||||
|
|
||||||
|
### Llama3.1-8B on 1 card (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama3.1-70B 8 cards (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama2-7B on 1 Card (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-7b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama2-70B on 8 cards (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-70b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llava-v1.6-Mistral-7B on 1 card (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||||
|
-e BATCH_BUCKET_SIZE=1 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||||
|
--max-total-tokens 8192 --max-batch-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
## FP8 Precision
|
||||||
|
|
||||||
|
Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the statistics of the model first before running the model in FP8 precision.
|
||||||
|
|
||||||
|
## Llama3.1-8B on 1 Card (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llama3.1-70B on 8 cards (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llama2-7B on 1 Card (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-7b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llama2-70B on 8 Cards (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-70b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llava-v1.6-Mistral-7B on 1 Card (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||||
|
-e BATCH_BUCKET_SIZE=1 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||||
|
--max-total-tokens 8192 --max-batch-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||||
|
-e BATCH_BUCKET_SIZE=1 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||||
|
--max-total-tokens 8192 --max-batch-size 4
|
||||||
|
```
|
backends/gaudi/server/.gitignore (new file, vendored, 164 lines)
@@ -0,0 +1,164 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
text_generation_server/__pycache__/
|
||||||
|
text_generation_server/pb/__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
transformers
|
||||||
|
safetensors
|
||||||
|
flash-attention/
|
||||||
|
flash-attention-v2/
|
||||||
|
vllm/
|
||||||
|
llm-awq/
|
||||||
|
eetq/
|
||||||
|
mamba/
|
backends/gaudi/server/Makefile (new file, 38 lines)
@@ -0,0 +1,38 @@
|
|||||||
|
include Makefile-flash-att
|
||||||
|
include Makefile-flash-att-v2
|
||||||
|
include Makefile-vllm
|
||||||
|
include Makefile-awq
|
||||||
|
include Makefile-eetq
|
||||||
|
include Makefile-selective-scan
|
||||||
|
|
||||||
|
PROTO_PATH ?= ../proto/v3
|
||||||
|
|
||||||
|
unit-tests:
|
||||||
|
pytest -s -vv -m "not private" tests
|
||||||
|
|
||||||
|
gen-server:
|
||||||
|
# Compile protos
|
||||||
|
pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
|
||||||
|
mkdir text_generation_server/pb || true
|
||||||
|
python -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \
|
||||||
|
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto
|
||||||
|
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||||
|
touch text_generation_server/pb/__init__.py
|
||||||
|
|
||||||
|
install: gen-server
|
||||||
|
pip install pip --upgrade
|
||||||
|
pip install --no-deps -r requirements.txt
|
||||||
|
pip install -e "."
|
||||||
|
|
||||||
|
run-dev:
|
||||||
|
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||||
|
|
||||||
|
install-poetry:
|
||||||
|
curl -sSL https://install.python-poetry.org | python3 -
|
||||||
|
|
||||||
|
update-lock:
|
||||||
|
rm poetry.lock
|
||||||
|
poetry lock --no-update
|
||||||
|
|
||||||
|
export-requirements:
|
||||||
|
poetry export -o requirements.txt --without-hashes
|
backends/gaudi/server/Makefile-awq (new file, 15 lines)
@@ -0,0 +1,15 @@
|
|||||||
|
# Fork that adds only the correct stream to this kernel in order
|
||||||
|
# to make cuda graphs work.
|
||||||
|
awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
|
||||||
|
|
||||||
|
awq:
|
||||||
|
rm -rf llm-awq
|
||||||
|
git clone https://github.com/huggingface/llm-awq
|
||||||
|
|
||||||
|
build-awq: awq
|
||||||
|
cd llm-awq/ && git fetch && git checkout $(awq_commit)
|
||||||
|
cd llm-awq/awq/kernels && python setup.py build
|
||||||
|
|
||||||
|
install-awq: build-awq
|
||||||
|
pip uninstall awq_inference_engine -y || true
|
||||||
|
cd llm-awq/awq/kernels && python setup.py install
|
backends/gaudi/server/Makefile-eetq (new file, 13 lines)
@@ -0,0 +1,13 @@
|
|||||||
|
eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
|
||||||
|
|
||||||
|
eetq:
|
||||||
|
# Clone eetq
|
||||||
|
pip install packaging
|
||||||
|
git clone https://github.com/NetEase-FuXi/EETQ.git eetq
|
||||||
|
|
||||||
|
build-eetq: eetq
|
||||||
|
cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
|
||||||
|
cd eetq && python setup.py build
|
||||||
|
|
||||||
|
install-eetq: build-eetq
|
||||||
|
cd eetq && python setup.py install
|
backends/gaudi/server/Makefile-fbgemm (new file, 15 lines)
@@ -0,0 +1,15 @@
|
|||||||
|
fbgemm_commit := v0.8.0
|
||||||
|
|
||||||
|
build-fbgemm:
|
||||||
|
@if [ ! -d "fbgemm" ]; then \
|
||||||
|
git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
|
||||||
|
fi
|
||||||
|
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
|
||||||
|
git submodule update --init --recursive && \
|
||||||
|
cd fbgemm_gpu && \
|
||||||
|
pip install -r requirements.txt && \
|
||||||
|
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
|
||||||
|
|
||||||
|
install-fbgemm: build-fbgemm
|
||||||
|
cd fbgemm/fbgemm_gpu && \
|
||||||
|
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
|
backends/gaudi/server/Makefile-flash-att (new file, 12 lines)
@@ -0,0 +1,12 @@
|
|||||||
|
flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
|
||||||
|
|
||||||
|
build-flash-attention:
|
||||||
|
if [ ! -d 'flash-attention' ]; then \
|
||||||
|
pip install -U packaging ninja --no-cache-dir && \
|
||||||
|
git clone https://github.com/HazyResearch/flash-attention.git; \
|
||||||
|
fi
|
||||||
|
cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
|
||||||
|
MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build
|
||||||
|
|
||||||
|
install-flash-attention: build-flash-attention
|
||||||
|
cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
|
backends/gaudi/server/Makefile-flash-att-v2 (new file, 21 lines)
@@ -0,0 +1,21 @@
|
|||||||
|
flash_att_v2_commit_cuda := v2.6.1
|
||||||
|
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
|
||||||
|
|
||||||
|
build-flash-attention-v2-cuda:
|
||||||
|
pip install -U packaging wheel
|
||||||
|
pip install flash-attn==$(flash_att_v2_commit_cuda)
|
||||||
|
|
||||||
|
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
||||||
|
echo "Flash v2 installed"
|
||||||
|
|
||||||
|
build-flash-attention-v2-rocm:
|
||||||
|
if [ ! -d 'flash-attention-v2' ]; then \
|
||||||
|
pip install -U packaging ninja --no-cache-dir && \
|
||||||
|
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
|
||||||
|
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
|
||||||
|
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
|
||||||
|
cd flash-attention-v2 && \
|
||||||
|
GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
|
backends/gaudi/server/Makefile-selective-scan (new file, 28 lines)
@@ -0,0 +1,28 @@
|
|||||||
|
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
|
||||||
|
|
||||||
|
causal-conv1d:
|
||||||
|
rm -rf causal-conv1d
|
||||||
|
git clone https://github.com/Dao-AILab/causal-conv1d.git
|
||||||
|
|
||||||
|
build-causal-conv1d: causal-conv1d
|
||||||
|
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
|
||||||
|
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
|
||||||
|
|
||||||
|
install-causal-conv1d: build-causal-conv1d
|
||||||
|
pip uninstall causal-conv1d -y || true
|
||||||
|
cd causal-conv1d/ && pip install .
|
||||||
|
|
||||||
|
# selective-scan depends on causal-conv1d
|
||||||
|
selective-scan:
|
||||||
|
rm -rf mamba
|
||||||
|
git clone https://github.com/state-spaces/mamba.git mamba
|
||||||
|
|
||||||
|
build-selective-scan: selective-scan
|
||||||
|
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
|
||||||
|
cd mamba && python setup.py build
|
||||||
|
|
||||||
|
install-selective-scan: install-causal-conv1d build-selective-scan
|
||||||
|
pip uninstall selective-scan-cuda -y || true
|
||||||
|
cd mamba && pip install .
|
||||||
|
|
||||||
|
build-all: build-causal-conv1d build-selective-scan
|
backends/gaudi/server/Makefile-vllm (new file, 23 lines)
@@ -0,0 +1,23 @@
|
|||||||
|
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
|
||||||
|
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
|
||||||
|
build-vllm-cuda:
|
||||||
|
if [ ! -d 'vllm' ]; then \
|
||||||
|
pip install -U ninja packaging --no-cache-dir && \
|
||||||
|
git clone https://github.com/Narsil/vllm.git vllm; \
|
||||||
|
fi
|
||||||
|
cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
|
||||||
|
|
||||||
|
install-vllm-cuda: build-vllm-cuda
|
||||||
|
cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
|
||||||
|
|
||||||
|
build-vllm-rocm:
|
||||||
|
if [ ! -d 'vllm' ]; then \
|
||||||
|
pip install -U ninja packaging --no-cache-dir && \
|
||||||
|
git clone https://github.com/mht-sharma/vllm.git vllm; \
|
||||||
|
fi
|
||||||
|
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||||
|
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||||
|
|
||||||
|
install-vllm-rocm: build-vllm-rocm
|
||||||
|
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||||
|
PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
|
backends/gaudi/server/README.md (new file, 15 lines)
@@ -0,0 +1,15 @@
|
|||||||
|
# Text Generation Inference Python gRPC Server
|
||||||
|
|
||||||
|
A Python gRPC server for Text Generation Inference
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```shell
|
||||||
|
make install
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
```shell
|
||||||
|
make run-dev
|
||||||
|
```
|
backends/gaudi/server/dill-0.3.7-patch.sh (new file, 91 lines)
@@ -0,0 +1,91 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
|
||||||
|
pushd dill
|
||||||
|
cat <<EOF > dill-0.3.7.patch
|
||||||
|
diff --git a/dill/_dill.py b/dill/_dill.py
|
||||||
|
index d0cf543..f6eb662 100644
|
||||||
|
--- a/dill/_dill.py
|
||||||
|
+++ b/dill/_dill.py
|
||||||
|
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
|
||||||
|
XRangeType = range
|
||||||
|
from types import MappingProxyType as DictProxyType, new_class
|
||||||
|
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
|
||||||
|
-import __main__ as _main_module
|
||||||
|
+class _LazyMainModule(object):
|
||||||
|
+ _module = None
|
||||||
|
+ @property
|
||||||
|
+ def module(self):
|
||||||
|
+ if self._module is None:
|
||||||
|
+ import __main__ as _m_module
|
||||||
|
+ self._module = _m_module
|
||||||
|
+ return self._module
|
||||||
|
+_main_module = _LazyMainModule()
|
||||||
|
import marshal
|
||||||
|
import gc
|
||||||
|
# import zlib
|
||||||
|
@@ -353,7 +361,7 @@ class Pickler(StockPickler):
|
||||||
|
_fmode = kwds.pop('fmode', None)
|
||||||
|
_recurse = kwds.pop('recurse', None)
|
||||||
|
StockPickler.__init__(self, file, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._diff_cache = {}
|
||||||
|
self._byref = settings['byref'] if _byref is None else _byref
|
||||||
|
self._strictio = False #_strictio
|
||||||
|
@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
|
||||||
|
settings = Pickler.settings
|
||||||
|
_ignore = kwds.pop('ignore', None)
|
||||||
|
StockUnpickler.__init__(self, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._ignore = settings['ignore'] if _ignore is None else _ignore
|
||||||
|
|
||||||
|
def load(self): #NOTE: if settings change, need to update attributes
|
||||||
|
obj = StockUnpickler.load(self)
|
||||||
|
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
|
||||||
|
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
|
||||||
|
if not self._ignore:
|
||||||
|
# point obj class to main
|
||||||
|
try: obj.__class__ = getattr(self._main, type(obj).__name__)
|
||||||
|
@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
|
||||||
|
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
|
||||||
|
logger.trace(pickler, "# D1")
|
||||||
|
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
|
||||||
|
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
|
||||||
|
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
|
||||||
|
logger.trace(pickler, "# D3")
|
||||||
|
- elif '__name__' in obj and obj != _main_module.__dict__ \\
|
||||||
|
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
|
||||||
|
and type(obj['__name__']) is str \\
|
||||||
|
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
|
||||||
|
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
|
||||||
|
diff --git a/dill/session.py b/dill/session.py
|
||||||
|
index 74234ab..1be8d89 100644
|
||||||
|
--- a/dill/session.py
|
||||||
|
+++ b/dill/session.py
|
||||||
|
@@ -233,7 +233,7 @@ def dump_module(
|
||||||
|
protocol = settings['protocol']
|
||||||
|
main = module
|
||||||
|
if main is None:
|
||||||
|
- main = _main_module
|
||||||
|
+ main = _main_module.module
|
||||||
|
elif isinstance(main, str):
|
||||||
|
main = _import_module(main)
|
||||||
|
if not isinstance(main, ModuleType):
|
||||||
|
@@ -501,7 +501,7 @@ def load_module(
|
||||||
|
pass
|
||||||
|
assert loaded is main
|
||||||
|
_restore_modules(unpickler, main)
|
||||||
|
- if main is _main_module or main is module:
|
||||||
|
+ if main is _main_module.module or main is module:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return main
|
||||||
|
|
||||||
|
EOF
|
||||||
|
git apply dill-0.3.7.patch
|
||||||
|
python -m pip install .
|
||||||
|
popd
|
||||||
|
rm -fr dill
|
backends/gaudi/server/dill-0.3.8-patch.sh (new file, 91 lines)
@@ -0,0 +1,91 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
|
||||||
|
pushd dill
|
||||||
|
cat <<EOF > dill-0.3.8.patch
|
||||||
|
diff --git a/dill/_dill.py b/dill/_dill.py
|
||||||
|
index d42432f..1d251e6 100644
|
||||||
|
--- a/dill/_dill.py
|
||||||
|
+++ b/dill/_dill.py
|
||||||
|
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
|
||||||
|
XRangeType = range
|
||||||
|
from types import MappingProxyType as DictProxyType, new_class
|
||||||
|
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
|
||||||
|
-import __main__ as _main_module
|
||||||
|
+class _LazyMainModule(object):
|
||||||
|
+ _module = None
|
||||||
|
+ @property
|
||||||
|
+ def module(self):
|
||||||
|
+ if self._module is None:
|
||||||
|
+ import __main__ as _m_module
|
||||||
|
+ self._module = _m_module
|
||||||
|
+ return self._module
|
||||||
|
+_main_module = _LazyMainModule()
|
||||||
|
import marshal
|
||||||
|
import gc
|
||||||
|
# import zlib
|
||||||
|
@@ -355,7 +363,7 @@ class Pickler(StockPickler):
|
||||||
|
_fmode = kwds.pop('fmode', None)
|
||||||
|
_recurse = kwds.pop('recurse', None)
|
||||||
|
StockPickler.__init__(self, file, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._diff_cache = {}
|
||||||
|
self._byref = settings['byref'] if _byref is None else _byref
|
||||||
|
self._strictio = False #_strictio
|
||||||
|
@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
|
||||||
|
settings = Pickler.settings
|
||||||
|
_ignore = kwds.pop('ignore', None)
|
||||||
|
StockUnpickler.__init__(self, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._ignore = settings['ignore'] if _ignore is None else _ignore
|
||||||
|
|
||||||
|
def load(self): #NOTE: if settings change, need to update attributes
|
||||||
|
obj = StockUnpickler.load(self)
|
||||||
|
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
|
||||||
|
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
|
||||||
|
if not self._ignore:
|
||||||
|
# point obj class to main
|
||||||
|
try: obj.__class__ = getattr(self._main, type(obj).__name__)
|
||||||
|
@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
|
||||||
|
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
|
||||||
|
logger.trace(pickler, "# D1")
|
||||||
|
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
|
||||||
|
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
|
||||||
|
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
|
||||||
|
logger.trace(pickler, "# D3")
|
||||||
|
- elif '__name__' in obj and obj != _main_module.__dict__ \\
|
||||||
|
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
|
||||||
|
and type(obj['__name__']) is str \\
|
||||||
|
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
|
||||||
|
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
|
||||||
|
diff --git a/dill/session.py b/dill/session.py
|
||||||
|
index e91068a..a921b43 100644
|
||||||
|
--- a/dill/session.py
|
||||||
|
+++ b/dill/session.py
|
||||||
|
@@ -233,7 +233,7 @@ def dump_module(
|
||||||
|
protocol = settings['protocol']
|
||||||
|
main = module
|
||||||
|
if main is None:
|
||||||
|
- main = _main_module
|
||||||
|
+ main = _main_module.module
|
||||||
|
elif isinstance(main, str):
|
||||||
|
main = _import_module(main)
|
||||||
|
if not isinstance(main, ModuleType):
|
||||||
|
@@ -501,7 +501,7 @@ def load_module(
|
||||||
|
pass
|
||||||
|
assert loaded is main
|
||||||
|
_restore_modules(unpickler, main)
|
||||||
|
- if main is _main_module or main is module:
|
||||||
|
+ if main is _main_module.module or main is module:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return main
|
||||||
|
|
||||||
|
EOF
|
||||||
|
git apply dill-0.3.8.patch
|
||||||
|
python -m pip install .
|
||||||
|
popd
|
||||||
|
rm -fr dill
|
@@ -0,0 +1,85 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from test_model import TEST_CONFIGS
|
||||||
|
|
||||||
|
UNKNOWN_CONFIGS = {
|
||||||
|
name: config
|
||||||
|
for name, config in TEST_CONFIGS.items()
|
||||||
|
if config["expected_greedy_output"] == "unknown"
|
||||||
|
or config["expected_batch_output"] == "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
|
||||||
|
def test_config(request) -> Dict[str, Any]:
|
||||||
|
"""Fixture that provides model configurations for testing."""
|
||||||
|
test_config = UNKNOWN_CONFIGS[request.param]
|
||||||
|
test_config["test_name"] = request.param
|
||||||
|
return test_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def test_name(test_config):
|
||||||
|
yield test_config["test_name"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def tgi_service(launcher, test_config, test_name) -> Generator:
|
||||||
|
"""Fixture that provides a TGI service for testing."""
|
||||||
|
with launcher(test_config["model_id"], test_name) as service:
|
||||||
|
yield service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_capture_expected_outputs(tgi_service, test_config, test_name):
|
||||||
|
"""Test that captures expected outputs for models with unknown outputs."""
|
||||||
|
print(f"Testing {test_name} with {test_config['model_id']}")
|
||||||
|
|
||||||
|
# Wait for service to be ready
|
||||||
|
await tgi_service.health(1000)
|
||||||
|
client = tgi_service.client
|
||||||
|
|
||||||
|
# Test single request (greedy)
|
||||||
|
print("Testing single request...")
|
||||||
|
response = await client.generate(
|
||||||
|
test_config["input"],
|
||||||
|
max_new_tokens=32,
|
||||||
|
)
|
||||||
|
greedy_output = response.generated_text
|
||||||
|
|
||||||
|
# Test multiple requests (batch)
|
||||||
|
print("Testing batch requests...")
|
||||||
|
responses = []
|
||||||
|
for _ in range(4):
|
||||||
|
response = await client.generate(
|
||||||
|
test_config["input"],
|
||||||
|
max_new_tokens=32,
|
||||||
|
)
|
||||||
|
responses.append(response.generated_text)
|
||||||
|
|
||||||
|
# Store results in a JSON file
|
||||||
|
output_file = "server/integration-tests/expected_outputs.json"
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# Try to load existing results if file exists
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
with open(output_file, "r") as f:
|
||||||
|
results = json.load(f)
|
||||||
|
|
||||||
|
# Update results for this model
|
||||||
|
results[test_name] = {
|
||||||
|
"model_id": test_config["model_id"],
|
||||||
|
"input": test_config["input"],
|
||||||
|
"greedy_output": greedy_output,
|
||||||
|
"batch_outputs": responses,
|
||||||
|
"args": test_config["args"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save updated results
|
||||||
|
with open(output_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
print(f"\nResults for {test_name} saved to {output_file}")
|
backends/gaudi/server/integration-tests/conftest.py (new file, 292 lines)
@@ -0,0 +1,292 @@
|
|||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import List
|
||||||
|
import socket
|
||||||
|
|
||||||
|
import docker
|
||||||
|
import pytest
|
||||||
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
|
from docker.errors import NotFound
|
||||||
|
from loguru import logger
|
||||||
|
from test_model import TEST_CONFIGS
|
||||||
|
from text_generation import AsyncClient
|
||||||
|
from text_generation.types import Response
|
||||||
|
|
||||||
|
# Use the latest image from the local docker build
|
||||||
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
|
||||||
|
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
|
||||||
|
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
HF_TOKEN is not None
|
||||||
|
), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
|
||||||
|
|
||||||
|
if DOCKER_VOLUME is None:
|
||||||
|
logger.warning(
|
||||||
|
"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
|
||||||
|
|
||||||
|
BASE_ENV = {
|
||||||
|
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||||
|
"LOG_LEVEL": LOG_LEVEL,
|
||||||
|
"HF_TOKEN": os.getenv("HF_TOKEN", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
HABANA_RUN_ARGS = {
|
||||||
|
"runtime": "habana",
|
||||||
|
"ipc_mode": "host",
|
||||||
|
"cap_add": ["sys_nice"],
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||||
|
level="INFO",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_container_logs(container, test_name):
|
||||||
|
"""Stream container logs in a separate thread."""
|
||||||
|
try:
|
||||||
|
for log in container.logs(stream=True, follow=True):
|
||||||
|
print(
|
||||||
|
f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
|
||||||
|
end="",
|
||||||
|
file=sys.stderr,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error streaming container logs: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
class LauncherHandle:
|
||||||
|
def __init__(self, port: int):
|
||||||
|
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
|
||||||
|
|
||||||
|
def _inner_health(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def health(self, timeout: int = 60):
|
||||||
|
assert timeout > 0
|
||||||
|
start_time = time.time()
|
||||||
|
logger.info(f"Starting health check with timeout of {timeout}s")
|
||||||
|
|
||||||
|
for attempt in range(timeout):
|
||||||
|
if not self._inner_health():
|
||||||
|
logger.error("Launcher crashed during health check")
|
||||||
|
raise RuntimeError("Launcher crashed")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.client.generate("test")
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.info(f"Health check passed after {elapsed:.1f}s")
|
||||||
|
return
|
||||||
|
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
|
||||||
|
if attempt == timeout - 1:
|
||||||
|
logger.error(f"Health check failed after {timeout}s: {str(e)}")
|
||||||
|
raise RuntimeError(f"Health check failed: {str(e)}")
|
||||||
|
if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
|
||||||
|
logger.debug(
|
||||||
|
f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
|
||||||
|
)
|
||||||
|
time.sleep(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during health check: {str(e)}")
|
||||||
|
# Get full traceback for debugging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"Full traceback:\n{traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerLauncherHandle(LauncherHandle):
|
||||||
|
def __init__(self, docker_client, container_name, port: int):
|
||||||
|
super(ContainerLauncherHandle, self).__init__(port)
|
||||||
|
self.docker_client = docker_client
|
||||||
|
self.container_name = container_name
|
||||||
|
|
||||||
|
def _inner_health(self) -> bool:
|
||||||
|
try:
|
||||||
|
container = self.docker_client.containers.get(self.container_name)
|
||||||
|
status = container.status
|
||||||
|
if status not in ["running", "created"]:
|
||||||
|
logger.warning(f"Container status is {status}")
|
||||||
|
# Get container logs for debugging
|
||||||
|
logs = container.logs().decode("utf-8")
|
||||||
|
logger.debug(f"Container logs:\n{logs}")
|
||||||
|
return status in ["running", "created"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking container health: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessLauncherHandle(LauncherHandle):
|
||||||
|
def __init__(self, process, port: int):
|
||||||
|
super(ProcessLauncherHandle, self).__init__(port)
|
||||||
|
self.process = process
|
||||||
|
|
||||||
|
def _inner_health(self) -> bool:
|
||||||
|
return self.process.poll() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def data_volume():
|
||||||
|
tmpdir = TemporaryDirectory()
|
||||||
|
yield tmpdir.name
|
||||||
|
try:
|
||||||
|
# Cleanup the temporary directory using sudo as it contains root files created by the container
|
||||||
|
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.error(f"Error cleaning up temporary directory: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def launcher(data_volume):
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def docker_launcher(
|
||||||
|
model_id: str,
|
||||||
|
test_name: str,
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Starting docker launcher for model {model_id} and test {test_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get a random available port
|
||||||
|
def get_free_port():
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", 0))
|
||||||
|
s.listen(1)
|
||||||
|
port = s.getsockname()[1]
|
||||||
|
return port
|
||||||
|
|
||||||
|
port = get_free_port()
|
||||||
|
logger.debug(f"Using port {port}")
|
||||||
|
|
||||||
|
client = docker.from_env()
|
||||||
|
|
||||||
|
container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
container = client.containers.get(container_name)
|
||||||
|
logger.info(
|
||||||
|
f"Stopping existing container {container_name} for test {test_name}"
|
||||||
|
)
|
||||||
|
container.stop()
|
||||||
|
container.wait()
|
||||||
|
except NotFound:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling existing container: {str(e)}")
|
||||||
|
|
||||||
|
model_name = next(
|
||||||
|
name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
|
||||||
|
)
|
||||||
|
|
||||||
|
tgi_args = TEST_CONFIGS[model_name]["args"].copy()
|
||||||
|
|
||||||
|
env = BASE_ENV.copy()
|
||||||
|
|
||||||
|
# Add model_id to env
|
||||||
|
env["MODEL_ID"] = model_id
|
||||||
|
|
||||||
|
# Add env config that is defined in the fixture parameter
|
||||||
|
if "env_config" in TEST_CONFIGS[model_name]:
|
||||||
|
env.update(TEST_CONFIGS[model_name]["env_config"].copy())
|
||||||
|
|
||||||
|
volumes = [f"{DOCKER_VOLUME}:/data"]
|
||||||
|
logger.debug(f"Using volume {volumes}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Creating container with name {container_name}")
|
||||||
|
|
||||||
|
# Log equivalent docker run command for debugging, this is not actually executed
|
||||||
|
container = client.containers.run(
|
||||||
|
DOCKER_IMAGE,
|
||||||
|
command=tgi_args,
|
||||||
|
name=container_name,
|
||||||
|
environment=env,
|
||||||
|
detach=True,
|
||||||
|
volumes=volumes,
|
||||||
|
ports={"80/tcp": port},
|
||||||
|
**HABANA_RUN_ARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Container {container_name} started successfully")
|
||||||
|
|
||||||
|
# Start log streaming in a background thread
|
||||||
|
log_thread = threading.Thread(
|
||||||
|
target=stream_container_logs,
|
||||||
|
args=(container, test_name),
|
||||||
|
daemon=True, # This ensures the thread will be killed when the main program exits
|
||||||
|
)
|
||||||
|
log_thread.start()
|
||||||
|
|
||||||
|
# Add a small delay to allow container to initialize
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Check container status after creation
|
||||||
|
status = container.status
|
||||||
|
logger.debug(f"Initial container status: {status}")
|
||||||
|
if status not in ["running", "created"]:
|
||||||
|
logs = container.logs().decode("utf-8")
|
||||||
|
logger.error(f"Container failed to start properly. Logs:\n{logs}")
|
||||||
|
|
||||||
|
yield ContainerLauncherHandle(client, container.name, port)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error starting container: {str(e)}")
|
||||||
|
# Get full traceback for debugging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"Full traceback:\n{traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
container = client.containers.get(container_name)
|
||||||
|
logger.info(f"Stopping container {container_name}")
|
||||||
|
container.stop()
|
||||||
|
container.wait()
|
||||||
|
|
||||||
|
container_output = container.logs().decode("utf-8")
|
||||||
|
print(container_output, file=sys.stderr)
|
||||||
|
|
||||||
|
container.remove()
|
||||||
|
logger.info(f"Container {container_name} removed successfully")
|
||||||
|
except NotFound:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning up container: {str(e)}")
|
||||||
|
|
||||||
|
return docker_launcher
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def generate_load():
|
||||||
|
async def generate_load_inner(
|
||||||
|
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
||||||
|
) -> List[Response]:
|
||||||
|
try:
|
||||||
|
futures = [
|
||||||
|
client.generate(
|
||||||
|
prompt,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
decoder_input_details=True,
|
||||||
|
)
|
||||||
|
for _ in range(n)
|
||||||
|
]
|
||||||
|
return await asyncio.gather(*futures)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating load: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return generate_load_inner
|
2 backends/gaudi/server/integration-tests/pytest.ini Normal file
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto
7 backends/gaudi/server/integration-tests/requirements.txt Normal file
@@ -0,0 +1,7 @@
pytest >= 8.3.5
pytest-asyncio >= 0.26.0
docker >= 7.1.0
Levenshtein >= 0.27.1
loguru >= 0.7.3
aiohttp >= 3.11.14
text-generation
276 backends/gaudi/server/integration-tests/test_model.py Normal file
@@ -0,0 +1,276 @@
from typing import Any, Dict

from text_generation import AsyncClient
import pytest
from Levenshtein import distance as levenshtein_distance

# The "args" config is not optimized for speed; it only checks that inference works for the different model architectures
TEST_CONFIGS = {
    "meta-llama/Llama-3.1-8B-Instruct-shared": {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
        "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
        "args": [
            "--sharded",
            "true",
            "--num-shard",
            "8",
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "8",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "meta-llama/Llama-3.1-8B-Instruct": {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
        "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
        "env_config": {},
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "meta-llama/Llama-2-7b-chat-hf": {
        "model_id": "meta-llama/Llama-2-7b-chat-hf",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
        "expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "mistralai/Mistral-7B-Instruct-v0.3": {
        "model_id": "mistralai/Mistral-7B-Instruct-v0.3",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "bigcode/starcoder2-3b": {
        "model_id": "bigcode/starcoder2-3b",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "google/gemma-7b-it": {
        "model_id": "google/gemma-7b-it",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "Qwen/Qwen2-0.5B-Instruct": {
        "model_id": "Qwen/Qwen2-0.5B-Instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
        "expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "tiiuae/falcon-7b-instruct": {
        "model_id": "tiiuae/falcon-7b-instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
        "expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "microsoft/phi-1_5": {
        "model_id": "microsoft/phi-1_5",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
        "expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "openai-community/gpt2": {
        "model_id": "openai-community/gpt2",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
        "expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "facebook/opt-125m": {
        "model_id": "facebook/opt-125m",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
        "expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "EleutherAI/gpt-j-6b": {
        "model_id": "EleutherAI/gpt-j-6b",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
}

print(f"Testing {len(TEST_CONFIGS)} models")


@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
    """Fixture that provides model configurations for testing."""
    test_config = TEST_CONFIGS[request.param]
    test_config["test_name"] = request.param
    return test_config


@pytest.fixture(scope="module")
def model_id(test_config):
    yield test_config["model_id"]


@pytest.fixture(scope="module")
def test_name(test_config):
    yield test_config["test_name"]


@pytest.fixture(scope="module")
def expected_outputs(test_config):
    return {
        "greedy": test_config["expected_greedy_output"],
        # "sampling": model_config["expected_sampling_output"],
        "batch": test_config["expected_batch_output"],
    }


@pytest.fixture(scope="module")
def input(test_config):
    return test_config["input"]


@pytest.fixture(scope="module")
def tgi_service(launcher, model_id, test_name):
    with launcher(model_id, test_name) as tgi_service:
        yield tgi_service


@pytest.fixture(scope="module")
async def tgi_client(tgi_service) -> AsyncClient:
    await tgi_service.health(1000)
    return tgi_service.client


@pytest.mark.asyncio
async def test_model_single_request(
    tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
):
    # Bounded greedy decoding without input
    response = await tgi_client.generate(
        input,
        max_new_tokens=32,
    )
    assert response.details.generated_tokens == 32
    assert response.generated_text == expected_outputs["greedy"]


@pytest.mark.asyncio
async def test_model_multiple_requests(
    tgi_client, generate_load, expected_outputs, input
):
    num_requests = 4
    responses = await generate_load(
        tgi_client,
        input,
        max_new_tokens=32,
        n=num_requests,
    )

    assert len(responses) == 4
    expected = expected_outputs["batch"]
    for r in responses:
        assert r.details.generated_tokens == 32
        # Compute the similarity with the expectation using the levenshtein distance
        # We should not have more than two substitutions or additions
        assert levenshtein_distance(r.generated_text, expected) < 3
3014 backends/gaudi/server/poetry.lock generated Normal file
File diff suppressed because it is too large
45 backends/gaudi/server/pyproject.toml Normal file
@@ -0,0 +1,45 @@
[tool.poetry]
name = "text-generation-server"
version = "2.0.4"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
text-generation-server = 'text_generation_server.cli:app'

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^5.0"
grpcio = "^1.71.1"
grpcio-status = "*"
grpcio-reflection = "*"
grpc-interceptor = "^0.15.0"
typer = "^0.15.0"
loguru = "^0.7.3"
opentelemetry-api = "^1.32.0"
opentelemetry-exporter-otlp = "^1.32.0"
opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.9"
sentencepiece = "^0.2.0"
peft = "^0.15"
optimum-habana = "1.17"
transformers = "^4.49"
numpy = "^1.26"
accelerate = "^0.33"
outlines = { version = "^0.0.36", optional = true }
prometheus-client = "^0.21.1"
py-cpuinfo = "^9.0.0"

[tool.poetry.group.dev.dependencies]
grpcio-tools = "*"
pytest = "^8.3.5"

[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.requires-plugins]
poetry-plugin-export = ">=1.8"
101 backends/gaudi/server/requirements.txt Normal file
@@ -0,0 +1,101 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
|
deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
|
||||||
|
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
|
wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
|
@@ -0,0 +1,13 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004

from text_generation_server.adapters.weights import (
    AdapterBatchData,
    AdapterBatchMetadata,
)

__all__ = [
    "AdapterBatchData",
    "AdapterBatchMetadata",
]
@@ -0,0 +1,30 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Set, Tuple

import torch

from text_generation_server.adapters.weights import AdapterWeights


@dataclass
class ModuleMap:
    module_name: str
    module_weights: Dict[str, Tuple[torch.Tensor, str]]


@dataclass
class AdapterConfig(ABC):
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass
471 backends/gaudi/server/text_generation_server/adapters/lora.py Normal file
@@ -0,0 +1,471 @@
# Origin: https://github.com/predibase/lorax
|
||||||
|
# Path: lorax/server/lorax_server/adapters/lora.py
|
||||||
|
# License: Apache License Version 2.0, January 2004
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import LoraConfig as _LoraConfig
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||||
|
|
||||||
|
from text_generation_server.adapters.weights import (
|
||||||
|
AdapterBatchMetadata,
|
||||||
|
AdapterWeights,
|
||||||
|
BatchAdapterWeights,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.sgmv import (
|
||||||
|
BGMV_MAX_RANK,
|
||||||
|
MAX_RANK_CUSTOM,
|
||||||
|
get_tmp_tensors,
|
||||||
|
orient_for_rank,
|
||||||
|
pad_rank,
|
||||||
|
use_cutlass_shrink,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||||
|
block_size = size // world_size
|
||||||
|
start = offset + rank * block_size
|
||||||
|
stop = offset + (rank + 1) * block_size
|
||||||
|
return start, stop
|
||||||
|
|
||||||
|
|
||||||
|
def shard_on_dim(
|
||||||
|
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
|
||||||
|
):
|
||||||
|
world_size = process_group.size()
|
||||||
|
rank = process_group.rank()
|
||||||
|
|
||||||
|
size = t.shape[dim]
|
||||||
|
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
|
||||||
|
|
||||||
|
if dim == 0:
|
||||||
|
tensor = t[start:stop]
|
||||||
|
elif dim == 1:
|
||||||
|
tensor = t[:, start:stop]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Let's make that generic when needed")
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def shard_lora_weights(
|
||||||
|
weights_a: List[torch.Tensor],
|
||||||
|
weights_b: List[torch.Tensor],
|
||||||
|
split_dim: int,
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||||
|
# [hidden_size, r]
|
||||||
|
weights_a = [
|
||||||
|
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
|
||||||
|
]
|
||||||
|
|
||||||
|
# [r, hidden_size]
|
||||||
|
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
|
||||||
|
|
||||||
|
return weights_a, weights_b
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoraConfig(AdapterConfig):
|
||||||
|
r: int
|
||||||
|
target_modules: Optional[Union[List[str], str]]
|
||||||
|
fan_in_fan_out: bool
|
||||||
|
lora_alpha: int
|
||||||
|
use_rslora: bool
|
||||||
|
|
||||||
|
def map_weights_for_model(
|
||||||
|
self,
|
||||||
|
adapter_weights: Dict[int, AdapterWeights],
|
||||||
|
weight_names: Tuple[str],
|
||||||
|
) -> Tuple[ModuleMap, Set[str]]:
|
||||||
|
adapter_weight_names = set()
|
||||||
|
module_map = {}
|
||||||
|
for weight_name in weight_names:
|
||||||
|
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
|
||||||
|
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
|
||||||
|
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module_map[weight_name] = {
|
||||||
|
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
|
||||||
|
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
|
||||||
|
}
|
||||||
|
adapter_weight_names.add(lora_a_name)
|
||||||
|
adapter_weight_names.add(lora_b_name)
|
||||||
|
return module_map, adapter_weight_names
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
|
||||||
|
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
|
||||||
|
return cls(
|
||||||
|
base_model_name_or_path=hf_config.base_model_name_or_path,
|
||||||
|
r=hf_config.r,
|
||||||
|
target_modules=hf_config.target_modules,
|
||||||
|
fan_in_fan_out=hf_config.fan_in_fan_out,
|
||||||
|
lora_alpha=hf_config.lora_alpha,
|
||||||
|
use_rslora=(
|
||||||
|
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraWeights(AdapterWeights):
|
||||||
|
"""LoRA weights for a single adapter merged across all layers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weights_a: List[torch.Tensor],
|
||||||
|
weights_b: List[torch.Tensor],
|
||||||
|
adapter_config: LoraConfig,
|
||||||
|
):
|
||||||
|
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
|
||||||
|
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
|
||||||
|
|
||||||
|
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
|
||||||
|
self._is_transposed = False
|
||||||
|
|
||||||
|
# [num_layers, hidden_size, r]
|
||||||
|
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
|
||||||
|
self._weights_a = torch.stack(weights_a)
|
||||||
|
|
||||||
|
# [num_layers, r, hidden_size]
|
||||||
|
self._weights_b = torch.stack(weights_b)
|
||||||
|
|
||||||
|
self.adapter_config = adapter_config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_a(self) -> torch.Tensor:
|
||||||
|
if self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_a
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_b(self) -> torch.Tensor:
|
||||||
|
if self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_b
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_a_t(self) -> torch.Tensor:
|
||||||
|
if not self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_a
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_b_t(self) -> torch.Tensor:
|
||||||
|
if not self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_b
|
||||||
|
|
||||||
|
def _transpose_weights(self):
|
||||||
|
if self._use_cutlass_shrink:
|
||||||
|
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
|
||||||
|
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
|
||||||
|
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
|
||||||
|
self._is_transposed = not self._is_transposed
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
|
||||||
|
return [BatchLoraWeights]
|
||||||
|
|
||||||
|
# prepare pre-loaded lora weights for use in the model.
|
||||||
|
#
|
||||||
|
# this method processes and organizes lora weights for a specific layer type across all layers:
|
||||||
|
# - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
|
||||||
|
# - retrieves weights from `module_map` based on the `layer_type`.
|
||||||
|
# - processes `nlayers` number of layers.
|
||||||
|
# - converts weights to the specified `dtype`.
|
||||||
|
# - shards weights across `world_size` number of processes using the `process_group`.
|
||||||
|
# - maps weights to specific layers using `target_to_layer`.
|
||||||
|
# - tracks `unused_weight_names` to identify any unused weights.
|
||||||
|
#
|
||||||
|
# the method handles weight transposition, scaling, and padding to ensure compatibility
|
||||||
|
# with SGMV or BGMV operations.
|
||||||
|
@classmethod
|
||||||
|
def prepare_weights(
|
||||||
|
cls,
|
||||||
|
config: LoraConfig,
|
||||||
|
module_map: Dict[str, Dict],
|
||||||
|
layer_type: str,
|
||||||
|
unused_weight_names: Set[str],
|
||||||
|
nlayers: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
world_size: int,
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
|
||||||
|
) -> Optional[AdapterWeights]:
|
||||||
|
lora_a_list = [None] * nlayers
|
||||||
|
lora_b_list = [None] * nlayers
|
||||||
|
|
||||||
|
for layer_id in range(nlayers):
|
||||||
|
key = (layer_id, layer_type)
|
||||||
|
weight_name, layer = target_to_layer[key]
|
||||||
|
base_weight = layer.base_layer.linear.weight
|
||||||
|
base_device = base_weight.device
|
||||||
|
|
||||||
|
if weight_name not in module_map:
|
||||||
|
# There is no LoRA weight for this layer type in the adapter
|
||||||
|
return None
|
||||||
|
|
||||||
|
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
|
||||||
|
lora_a = lora_a.to(base_device, dtype)
|
||||||
|
|
||||||
|
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
|
||||||
|
lora_b = lora_b.to(base_device, dtype)
|
||||||
|
|
||||||
|
scale = get_scaling_factor(
|
||||||
|
config.lora_alpha,
|
||||||
|
config.r,
|
||||||
|
uses_rslora=config.use_rslora,
|
||||||
|
)
|
||||||
|
|
||||||
|
unused_weight_names.discard(lora_a_name)
|
||||||
|
unused_weight_names.discard(lora_b_name)
|
||||||
|
|
||||||
|
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
|
||||||
|
# (A * B) * C = A * (B * C)
|
||||||
|
lora_a_list[layer_id] = lora_a.transpose(0, 1)
|
||||||
|
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
|
||||||
|
|
||||||
|
# pad lora ranks to be compatible with sgmv
|
||||||
|
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
|
||||||
|
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
|
||||||
|
|
||||||
|
if lora_a_list:
|
||||||
|
# update rank if it was padded
|
||||||
|
padded_rank = lora_a_list[0].size(1)
|
||||||
|
config.r = padded_rank
|
||||||
|
|
||||||
|
return LoraWeights(
|
||||||
|
*shard_lora_weights(
|
||||||
|
weights_a=lora_a_list,
|
||||||
|
weights_b=lora_b_list,
|
||||||
|
split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
|
||||||
|
process_group=process_group,
|
||||||
|
),
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RankSegments:
|
||||||
|
rank: int
|
||||||
|
|
||||||
|
lora_a_ptr: torch.Tensor
|
||||||
|
lora_b_ptr: torch.Tensor
|
||||||
|
|
||||||
|
# prefill (sgmv)
|
||||||
|
tmp_shrink: torch.Tensor
|
||||||
|
tmp_expand: torch.Tensor
|
||||||
|
segment_starts: torch.Tensor
|
||||||
|
segment_ends: torch.Tensor
|
||||||
|
|
||||||
|
# decode (bgmv)
|
||||||
|
indices: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchLoraWeights(BatchAdapterWeights):
|
||||||
|
lora_a: Dict[int, torch.Tensor]
|
||||||
|
lora_b: Dict[int, torch.Tensor]
|
||||||
|
adapter_index_configs: Dict[int, LoraConfig]
|
||||||
|
rank_data: Dict[int, RankSegments]
|
||||||
|
use_sgmv: bool
|
||||||
|
|
||||||
|
def has_adapter(self, adapter_index: int) -> bool:
|
||||||
|
return adapter_index in self.adapter_index_configs
|
||||||
|
|
||||||
|
def can_vectorize(self, pg: ProcessGroup) -> bool:
|
||||||
|
return all(
|
||||||
|
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
|
||||||
|
for rank_data in self.rank_data.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
adapter_weights: Dict[int, AdapterWeights],
|
||||||
|
meta: AdapterBatchMetadata,
|
||||||
|
prefill: bool,
|
||||||
|
prefill_head_indices: Optional[torch.Tensor],
|
||||||
|
) -> Optional["BatchLoraWeights"]:
|
||||||
|
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
|
||||||
|
adapter_weights = {
|
||||||
|
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
|
||||||
|
}
|
||||||
|
if not adapter_weights:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_weights = next(iter(adapter_weights.values()))
|
||||||
|
device = first_weights.weights_a.device
|
||||||
|
segment_indices = meta.segment_indices
|
||||||
|
|
||||||
|
lora_a = {
|
||||||
|
idx: adapter_weights[idx].weights_a
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
}
|
||||||
|
lora_b = {
|
||||||
|
idx: adapter_weights[idx].weights_b
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
}
|
||||||
|
|
||||||
|
max_rank = max(
|
||||||
|
(
|
||||||
|
adapter_weights[idx].lora_a_r
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
),
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prefill or max_rank > BGMV_MAX_RANK:
|
||||||
|
use_sgmv = True
|
||||||
|
lora_a_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_a.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
lora_b_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_b.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
use_sgmv = False
|
||||||
|
lora_a_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_a_t.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
lora_b_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_b_t.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_index_configs = {
|
||||||
|
idx: adapter_weights[idx].adapter_config
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
|
||||||
|
|
||||||
|
rank_indices = defaultdict(list)
|
||||||
|
for segment_idx, adapter_idx in enumerate(segment_indices):
|
||||||
|
if adapter_idx not in adapter_weights:
|
||||||
|
continue
|
||||||
|
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
|
||||||
|
|
||||||
|
if prefill_head_indices is not None:
|
||||||
|
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
|
||||||
|
for head_index in prefill_head_indices:
|
||||||
|
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
|
||||||
|
if head_index < meta.adapter_segments[j]:
|
||||||
|
prefill_head_segment_ends[-1] += 1
|
||||||
|
else:
|
||||||
|
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
|
||||||
|
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
rank_data = {}
|
||||||
|
for rank, indices in rank_indices.items():
|
||||||
|
tmp_shrink = None
|
||||||
|
tmp_expand = None
|
||||||
|
segment_starts = None
|
||||||
|
segment_ends = None
|
||||||
|
batch_indices = None
|
||||||
|
|
||||||
|
if use_sgmv:
|
||||||
|
lora_a_ptr_indices = lora_a_ptr[indices]
|
||||||
|
tmp_shrink, tmp_expand = get_tmp_tensors(
|
||||||
|
lora_a_ptr_indices.size(0), rank, device
|
||||||
|
)
|
||||||
|
segment_starts = meta.adapter_segments[indices]
|
||||||
|
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
|
||||||
|
if prefill_head_indices is not None:
|
||||||
|
for i, segment_index in enumerate(indices):
|
||||||
|
segment_starts[i] = prefill_head_segment_starts[segment_index]
|
||||||
|
segment_ends[i] = prefill_head_segment_ends[segment_index]
|
||||||
|
else:
|
||||||
|
rank_indices = set(indices)
|
||||||
|
batch_indices = [
|
||||||
|
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
|
||||||
|
]
|
||||||
|
batch_indices = [
|
||||||
|
idx if idx in rank_indices else -1 for idx in batch_indices
|
||||||
|
]
|
||||||
|
batch_indices = torch.tensor(
|
||||||
|
batch_indices, dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
rank_data[rank] = RankSegments(
|
||||||
|
rank=rank,
|
||||||
|
tmp_shrink=tmp_shrink,
|
||||||
|
tmp_expand=tmp_expand,
|
||||||
|
lora_a_ptr=lora_a_ptr[indices],
|
||||||
|
lora_b_ptr=lora_b_ptr[indices],
|
||||||
|
segment_starts=segment_starts,
|
||||||
|
segment_ends=segment_ends,
|
||||||
|
indices=batch_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BatchLoraWeights(
|
||||||
|
lora_a=lora_a,
|
||||||
|
lora_b=lora_b,
|
||||||
|
adapter_index_configs=adapter_index_configs,
|
||||||
|
rank_data=rank_data,
|
||||||
|
use_sgmv=use_sgmv,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_scaling_factor(
|
||||||
|
lora_alpha: int,
|
||||||
|
r: int,
|
||||||
|
uses_rslora: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Computes the scaling factor for the lora weights."""
|
||||||
|
if uses_rslora:
|
||||||
|
return lora_alpha / (r**0.5)
|
||||||
|
return lora_alpha / r
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
|
||||||
|
if hasattr(v, "lora_weights"):
|
||||||
|
return v.lora_weights
|
||||||
|
return v
|
146 backends/gaudi/server/text_generation_server/adapters/weights.py Normal file
@@ -0,0 +1,146 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004

from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type

import torch


@dataclass
class AdapterBatchMetadata:
    # [batch_size]
    adapter_indices: torch.Tensor

    # [num_adapters]
    adapter_set: Set[int]

    # [num_segments + 1]
    adapter_segments: torch.Tensor

    # [num_segments]
    # maps from segment index to adapter index, i.e.:
    # segment_indices[s] == adapter_indices[i]
    segment_indices: List[int]


class AdapterWeights(ABC):
    @abstractclassmethod
    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
        pass

    @property
    def speculative_tokens(self) -> int:
        return 0


class BatchAdapterWeights(ABC):
    @abstractclassmethod
    def has_adapter(self, adapter_index: int) -> bool:
        pass

    @abstractclassmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: torch.Tensor,
    ) -> Optional["BatchAdapterWeights"]:
        pass


class LayerAdapterWeights:
    """Adapter weights that apply to a particular layer."""

    def __init__(self):
        self.adapter_weights: Dict[int, AdapterWeights] = {}

    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
        self.adapter_weights[adapter_idx] = weights

    def remove_adapter(self, adapter_idx: int):
        if adapter_idx not in self.adapter_weights:
            return
        del self.adapter_weights[adapter_idx]

    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

    def get_data(
        self,
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Dict[str, BatchAdapterWeights]:
        # bucket adapters by batch class
        adapter_batch_types: Dict[
            Type[BatchAdapterWeights], Dict[int, AdapterWeights]
        ] = defaultdict(dict)
        for adapter_index, adapter_weights in self.adapter_weights.items():
            for batch_type in adapter_weights.get_batch_types():
                adapter_batch_types[batch_type][adapter_index] = adapter_weights

        batch_data = {}
        for batch_type, adapter_weights in adapter_batch_types.items():
            batched_weights = batch_type.load(
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                batch_data = batched_weights
        return batch_data


@dataclass
class AdapterBatchData:
    meta: AdapterBatchMetadata

    # layer type -> adapter type -> batch weight data
    data: Dict[str, Dict[str, BatchAdapterWeights]]

    prefill: bool

    @staticmethod
    def from_meta(
        meta: AdapterBatchMetadata,
        weights: Dict[str, LayerAdapterWeights],
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> "AdapterBatchData":
        data = {}
        for k, v in weights.items():
            if v.is_empty():
                continue
            data[k] = v.get_data(
                meta, prefill, prefill_head_indices if k == "lm_head" else None
            )
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)

    def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for lora_data in self.data.values():
            if lora_data is None:
                continue

            for rank_data in lora_data.rank_data.values():
                ranks.add(rank_data.rank)

        return ranks

    def layer_names(self) -> Set[str]:
        return set(self.data.keys())

    def adapter_keys(self) -> Set[str]:
        adapter_keys = set()
        for layer_data in self.data.values():
            adapter_keys.update(layer_data.keys())
        return adapter_keys

    @property
    def max_rank(self) -> int:
        ranks = self.ranks()
        return max(ranks) if len(ranks) > 0 else 0
34 backends/gaudi/server/text_generation_server/cache.py Normal file
@@ -0,0 +1,34 @@
import torch

from typing import Dict, Optional, TypeVar

from text_generation_server.models.types import Batch

B = TypeVar("B", bound=Batch)


class Cache:
    def __init__(self):
        self.cache: Dict[int, B] = {}

    def pop(self, batch_id: int) -> Optional[B]:
        return self.cache.pop(batch_id, None)

    def set(self, entry: B):
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        batch = self.pop(batch_id)
        if batch is not None:
            del batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def clear(self):
        keys = list(self.cache.keys())
        for k in keys:
            self.delete(k)

    def __len__(self):
        return len(self.cache.keys())
426 backends/gaudi/server/text_generation_server/cli.py Normal file
@@ -0,0 +1,426 @@
import os
|
||||||
|
import psutil
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from text_generation_server.utils.adapter import parse_lora_adapters
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
class Quantization(str, Enum):
|
||||||
|
gptq = "gptq"
|
||||||
|
awq = "awq"
|
||||||
|
fp8 = "fp8"
|
||||||
|
|
||||||
|
|
||||||
|
class Dtype(str, Enum):
|
||||||
|
float16 = "float16"
|
||||||
|
bloat16 = "bfloat16"
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def serve(
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
sharded: bool = False,
|
||||||
|
quantize: Optional[Quantization] = None,
|
||||||
|
speculate: Optional[int] = None,
|
||||||
|
dtype: Optional[Dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
uds_path: Path = "/tmp/text-generation-server",
|
||||||
|
logger_level: str = "INFO",
|
||||||
|
json_output: bool = False,
|
||||||
|
otlp_endpoint: Optional[str] = None,
|
||||||
|
otlp_service_name: str = "text-generation-inference.server",
|
||||||
|
max_input_tokens: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if sharded:
|
||||||
|
# assert (
|
||||||
|
# os.getenv("RANK", None) is not None
|
||||||
|
# ), "RANK must be set when sharded is True"
|
||||||
|
assert (
|
||||||
|
os.getenv("WORLD_SIZE", None) is not None
|
||||||
|
), "WORLD_SIZE must be set when sharded is True"
|
||||||
|
assert (
|
||||||
|
os.getenv("MASTER_ADDR", None) is not None
|
||||||
|
), "MASTER_ADDR must be set when sharded is True"
|
||||||
|
assert (
|
||||||
|
os.getenv("MASTER_PORT", None) is not None
|
||||||
|
), "MASTER_PORT must be set when sharded is True"
|
||||||
|
|
||||||
|
# Remove default handler
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
format="{message}",
|
||||||
|
filter="text_generation_server",
|
||||||
|
level=logger_level,
|
||||||
|
serialize=json_output,
|
||||||
|
backtrace=True,
|
||||||
|
diagnose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import here after the logger is added to log potential import exceptions
|
||||||
|
from text_generation_server import server
|
||||||
|
from text_generation_server.tracing import setup_tracing
|
||||||
|
|
||||||
|
# Setup OpenTelemetry distributed tracing
|
||||||
|
if otlp_endpoint is not None:
|
||||||
|
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
|
||||||
|
|
||||||
|
lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
|
||||||
|
|
||||||
|
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
|
||||||
|
# and warn the user
|
||||||
|
if lora_adapters:
|
||||||
|
logger.warning("LoRA adapters enabled (experimental feature).")
|
||||||
|
|
||||||
|
if "CUDA_GRAPHS" in os.environ:
|
||||||
|
logger.warning(
|
||||||
|
"LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
|
||||||
|
)
|
||||||
|
global CUDA_GRAPHS
|
||||||
|
CUDA_GRAPHS = None
|
||||||
|
|
||||||
|
# Downgrade enum into str for easier management later on
|
||||||
|
quantize = None if quantize is None else quantize.value
|
||||||
|
dtype = "bfloat16" if dtype is None else dtype.value
|
||||||
|
logger.info(f"quantize={quantize}")
|
||||||
|
if dtype is not None and quantize not in {
|
||||||
|
None,
|
||||||
|
"bitsandbytes",
|
||||||
|
"bitsandbytes-nf4",
|
||||||
|
"bitsandbytes-fp4",
|
||||||
|
"gptq",
|
||||||
|
"awq",
|
||||||
|
"fp8",
|
||||||
|
}:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
|
||||||
|
|
||||||
|
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
|
||||||
|
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
|
||||||
|
num_shard = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
logger.info("CLI SHARDED = {}".format(num_shard))
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
cmd = (
|
||||||
|
f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
|
||||||
|
)
|
||||||
|
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
|
||||||
|
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
|
||||||
|
cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
|
||||||
|
if speculate is not None:
|
||||||
|
cmd += f"--speculate {speculate}"
|
||||||
|
logger.info("CLI server start deepspeed ={} ".format(cmd))
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
|
||||||
|
do_terminate = False
|
||||||
|
current_handler = signal.getsignal(signal.SIGTERM)
|
||||||
|
|
||||||
|
def terminate_handler(sig, frame):
|
||||||
|
nonlocal do_terminate
|
||||||
|
do_terminate = True
|
||||||
|
if callable(current_handler):
|
||||||
|
current_handler(sig, frame)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, terminate_handler)
|
||||||
|
|
||||||
|
finished = False
|
||||||
|
while not finished:
|
||||||
|
try:
|
||||||
|
if do_terminate:
|
||||||
|
parent = psutil.Process(proc.pid)
|
||||||
|
all_procs = parent.children(recursive=True) + [parent]
|
||||||
|
for p in all_procs:
|
||||||
|
try:
|
||||||
|
p.terminate()
|
||||||
|
except psutil.NoSuchProcess:
|
||||||
|
pass
|
||||||
|
_, alive = psutil.wait_procs(all_procs, timeout=30)
|
||||||
|
for p in alive:
|
||||||
|
p.kill()
|
||||||
|
|
||||||
|
do_terminate = False
|
||||||
|
|
||||||
|
proc.wait(timeout=3)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
finished = True
|
||||||
|
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
if proc.returncode != 0:
|
||||||
|
logger.error(f"{cmd} exited with status = {proc.returncode}")
|
||||||
|
return proc.returncode
|
||||||
|
else:
|
||||||
|
server.serve(
|
||||||
|
model_id,
|
||||||
|
lora_adapters,
|
||||||
|
revision,
|
||||||
|
sharded,
|
||||||
|
quantize,
|
||||||
|
speculate,
|
||||||
|
dtype,
|
||||||
|
trust_remote_code,
|
||||||
|
uds_path,
|
||||||
|
max_input_tokens,
|
||||||
|
)
|
||||||
|
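# For reference, a hedged illustration of the launcher command that the sharded branch
# above assembles before handing it to `subprocess.Popen`. All values below are
# placeholders, not taken from this diff:
#
#   deepspeed --num_nodes 1 --num_gpus 8 --no_local_rank /path/to/tgi_service.py \
#     --model_id <model_id> --revision main --sharded true \
#     --dtype bfloat16 --trust_remote_code false --uds_path /tmp/text-generation-server \
#     --quantize None --max_input_tokens 4096 --speculate 2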
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def download_weights(
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
extension: str = ".safetensors",
|
||||||
|
auto_convert: bool = True,
|
||||||
|
logger_level: str = "INFO",
|
||||||
|
json_output: bool = False,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
merge_lora: bool = False,
|
||||||
|
):
|
||||||
|
# Remove default handler
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
format="{message}",
|
||||||
|
filter="text_generation_server",
|
||||||
|
level=logger_level,
|
||||||
|
serialize=json_output,
|
||||||
|
backtrace=True,
|
||||||
|
diagnose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import here after the logger is added to log potential import exceptions
|
||||||
|
from text_generation_server import utils
|
||||||
|
|
||||||
|
# Check whether the files were already downloaded
|
||||||
|
try:
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
logger.info("Files are already present on the host. " "Skipping download.")
|
||||||
|
return
|
||||||
|
# Local files not found
|
||||||
|
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
|
||||||
|
"WEIGHTS_CACHE_OVERRIDE", None
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
if not is_local_model:
|
||||||
|
# TODO: maybe reverse the default value of merge_lora?
|
||||||
|
# currently by default we don't merge the weights with the base model
|
||||||
|
if merge_lora:
|
||||||
|
try:
|
||||||
|
hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="adapter_config.json"
|
||||||
|
)
|
||||||
|
utils.download_and_unload_peft(
|
||||||
|
model_id, revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
is_local_model = True
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
return
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
utils.peft.download_peft(
|
||||||
|
model_id, revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
|
||||||
|
config = hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
|
with open(config, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
base_model_id = config.get("base_model_name_or_path", None)
|
||||||
|
if base_model_id and base_model_id != model_id:
|
||||||
|
try:
|
||||||
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
|
download_weights(
|
||||||
|
model_id=base_model_id,
|
||||||
|
revision="main",
|
||||||
|
extension=extension,
|
||||||
|
auto_convert=auto_convert,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to download weights from the hub
|
||||||
|
try:
|
||||||
|
filenames = utils.weight_hub_files(model_id, revision, extension)
|
||||||
|
utils.download_weights(filenames, model_id, revision)
|
||||||
|
# Successfully downloaded weights
|
||||||
|
return
|
||||||
|
|
||||||
|
# No weights found on the hub with this extension
|
||||||
|
except utils.EntryNotFoundError as e:
|
||||||
|
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
|
||||||
|
if not extension == ".safetensors" or not auto_convert:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
elif (Path(model_id) / "adapter_config.json").exists():
|
||||||
|
# Try to load as a local PEFT model
|
||||||
|
try:
|
||||||
|
utils.download_and_unload_peft(
|
||||||
|
model_id, revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
return
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
elif (Path(model_id) / "config.json").exists():
|
||||||
|
# Try to load as a local Medusa model
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
|
||||||
|
config = Path(model_id) / "config.json"
|
||||||
|
with open(config, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
base_model_id = config.get("base_model_name_or_path", None)
|
||||||
|
if base_model_id:
|
||||||
|
try:
|
||||||
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
|
download_weights(
|
||||||
|
model_id=base_model_id,
|
||||||
|
revision="main",
|
||||||
|
extension=extension,
|
||||||
|
auto_convert=auto_convert,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to see if there are local pytorch weights
|
||||||
|
try:
|
||||||
|
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||||
|
try:
|
||||||
|
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
||||||
|
except Exception:
|
||||||
|
local_pt_files = utils.weight_files(model_id, revision, ".pt")
|
||||||
|
|
||||||
|
# No local pytorch weights
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
if extension == ".safetensors":
|
||||||
|
logger.warning(
|
||||||
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
|
f"Downloading PyTorch weights."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to see if there are pytorch weights on the hub
|
||||||
|
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
|
||||||
|
# Download pytorch weights
|
||||||
|
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
||||||
|
|
||||||
|
if auto_convert:
|
||||||
|
if not trust_remote_code:
|
||||||
|
logger.warning(
|
||||||
|
"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
|
||||||
|
"Pickle files are unsafe and can essentially contain remote code execution!"
|
||||||
|
"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
|
f"Converting PyTorch weights to safetensors."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Safetensors final filenames
|
||||||
|
local_st_files = [
|
||||||
|
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
||||||
|
for p in local_pt_files
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
import transformers
|
||||||
|
import json
|
||||||
|
|
||||||
|
if is_local_model:
|
||||||
|
config_filename = os.path.join(model_id, "config.json")
|
||||||
|
else:
|
||||||
|
config_filename = hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
|
with open(config_filename, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
architecture = config["architectures"][0]
|
||||||
|
|
||||||
|
class_ = getattr(transformers, architecture)
|
||||||
|
|
||||||
|
# The name of this variable depends on the transformers version.
|
||||||
|
discard_names = getattr(class_, "_tied_weights_keys", [])
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
discard_names = []
|
||||||
|
# Convert pytorch weights to safetensors
|
||||||
|
utils.convert_files(local_pt_files, local_st_files, discard_names)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def quantize(
|
||||||
|
model_id: str,
|
||||||
|
output_dir: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
logger_level: str = "INFO",
|
||||||
|
json_output: bool = False,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
upload_to_model_id: Optional[str] = None,
|
||||||
|
percdamp: float = 0.01,
|
||||||
|
act_order: bool = False,
|
||||||
|
groupsize: int = 128,
|
||||||
|
):
|
||||||
|
if revision is None:
|
||||||
|
revision = "main"
|
||||||
|
download_weights(
|
||||||
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.gptq.quantize import quantize
|
||||||
|
|
||||||
|
quantize(
|
||||||
|
model_id=model_id,
|
||||||
|
bits=4,
|
||||||
|
groupsize=groupsize,
|
||||||
|
output_dir=output_dir,
|
||||||
|
revision=revision,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
upload_to_model_id=upload_to_model_id,
|
||||||
|
percdamp=percdamp,
|
||||||
|
act_order=act_order,
|
||||||
|
sym=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
@@ -0,0 +1,53 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import os
import habana_frameworks.torch as htorch

quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""

if is_quantization_enabled:
    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
    os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")


def patch_scoped_linear_all_reduce(model):
    from deepspeed.module_inject.layers import LinearAllreduce
    from optimum.habana.transformers.models.modeling_all_models import (
        ScopedLinearAllReduce,
    )

    for name, module in model.named_children():
        if type(module) is LinearAllreduce:
            SL = ScopedLinearAllReduce(mod=module)
            setattr(model, name, SL)
        patch_scoped_linear_all_reduce(module)


def setup_quantization(model):
    if is_quantization_enabled:
        htorch.core.quantization._mark_params_as_const(model)
        htorch.core.quantization._check_params_as_const(model)
        htorch.core.hpu_initialize(model)
    return model


def prepare_model_for_quantization(model):
    if is_quantization_enabled:
        if model.config.model_type in [
            "llama",
            "falcon",
            "qwen2",
            "starcoder2",
            "gemma",
        ]:
            patch_scoped_linear_all_reduce(model)
        from neural_compressor.torch.quantization import FP8Config, convert

        config = FP8Config.from_json_file(quant_config)
        model = convert(model, config)
    return model
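# A minimal usage sketch for the helpers above, assuming this module is importable as
# `habana_quantization_env` (assumed name) and that an HPU-enabled PyTorch build is present.
# The model class and checkpoint id are illustrative placeholders, not part of this diff.
from transformers import AutoModelForCausalLM

import habana_quantization_env as quant_env

model = AutoModelForCausalLM.from_pretrained("org/model").eval().to("hpu")
model = quant_env.prepare_model_for_quantization(model)  # FP8-converts only when QUANT_CONFIG is set
model = quant_env.setup_quantization(model)  # marks params const and calls hpu_initialize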
backends/gaudi/server/text_generation_server/interceptor.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import torch
import grpc

from google.rpc import status_pb2, code_pb2
from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
import traceback
import os


class ExceptionInterceptor(AsyncServerInterceptor):
    async def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        try:
            response = method(request_or_iterator, context)
            return await response
        except Exception as err:
            trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
            method_name = method_name.split("/")[-1]
            logger.exception(f"Method {method_name} encountered an error.")

            # Runtime Error cannot be recovered from
            if isinstance(err, RuntimeError):
                exit(1)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            from .utils.debug import dbg_trace

            dbg_trace("EXCEPTION", traceback.format_exc())
            await context.abort_with_status(
                rpc_status.to_status(
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                )
            )
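# A hedged sketch of how an interceptor like this one is typically registered on a
# grpc.aio server; the socket path is a placeholder and servicer registration is omitted.
import grpc


async def _serve_with_interceptor():
    server = grpc.aio.server(interceptors=[ExceptionInterceptor()])
    server.add_insecure_port("unix:///tmp/text-generation-server-0")
    await server.start()
    await server.wait_for_termination()


# Run with: asyncio.run(_serve_with_interceptor())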
@@ -0,0 +1,34 @@
from text_generation_server.layers.tensor_parallel import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
)
from text_generation_server.layers.linear import (
    get_linear,
    FastLinear,
)
from text_generation_server.layers.speculative import SpeculativeHead

# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d

from text_generation_server.layers.lora import (
    LoraLinear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)

__all__ = [
    "get_linear",
    "FastLinear",
    "TensorParallelColumnLinear",
    "TensorParallelRowLinear",
    "TensorParallelEmbedding",
    "SpeculativeHead",
    "LoraLinear",
    "TensorParallelMultiAdapterLinear",
    "TensorParallelAdapterRowLinear",
    "load_layer_norm",
    "load_conv2d",
]
@@ -0,0 +1,28 @@
from .common import (
    Seqlen,
    HPUPagedAttentionMetadata,
    trim_attn_metadata,
    trim_seqlen_metadata,
)

from .hpu import (
    SUPPORTS_WINDOWING,
    attention,
    paged_attention,
)


# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales

__all__ = [
    "attention",
    "get_kv_scales",
    "paged_attention",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "Seqlen",
    "HPUPagedAttentionMetadata",
    "trim_seqlen_metadata",
    "trim_attn_metadata",
]
@ -0,0 +1,147 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import torch
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
import collections
|
||||||
|
|
||||||
|
_TYPE_CACHE = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HPUPagedAttentionMetadata:
|
||||||
|
"""Metadata for PagedAttention."""
|
||||||
|
|
||||||
|
block_list: Optional[torch.Tensor]
|
||||||
|
block_mapping: Optional[torch.Tensor]
|
||||||
|
block_usage: Optional[torch.Tensor]
|
||||||
|
block_scales: Optional[torch.Tensor]
|
||||||
|
block_groups: Optional[torch.Tensor]
|
||||||
|
attn_bias: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def subtuple(
|
||||||
|
obj: object,
|
||||||
|
typename: str,
|
||||||
|
to_copy: List[str],
|
||||||
|
to_override: Optional[Dict[str, object]] = None,
|
||||||
|
):
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
if to_override is None:
|
||||||
|
to_override = {}
|
||||||
|
fields = set(to_copy) | set(to_override.keys())
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
values = {key: obj[key] for key in fields if key in obj}
|
||||||
|
else:
|
||||||
|
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
|
||||||
|
if typename not in _TYPE_CACHE:
|
||||||
|
_TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
|
||||||
|
return _TYPE_CACHE[typename](**values)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
|
||||||
|
# NOTE(kzawora): To anyone working on this in the future:
|
||||||
|
# Trimming metadata is required when using HPUGraphs.
|
||||||
|
# Attention metadata is going to be hashed by PT bridge, and
|
||||||
|
# appropriate HPUGraphs will be matched based on all inputs' hash.
|
||||||
|
|
||||||
|
# Before you put more keys in here, make sure you know their
|
||||||
|
# value type and make sure you know how it's going to be hashed.
|
||||||
|
# You can find that information in input_hash function
|
||||||
|
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
|
||||||
|
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
|
||||||
|
|
||||||
|
# If you use primitive types here - they will get hashed based
|
||||||
|
# on their value. You *will* get lots of excessive graph captures
|
||||||
|
# (and an OOM eventually) if you decide to put something like
|
||||||
|
# seq_len int here.
|
||||||
|
# If you absolutely need a scalar, put it in a tensor. Tensors
|
||||||
|
# get hashed using their metadata, not their values:
|
||||||
|
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
|
||||||
|
# input_hash(123) != input_hash(321)
|
||||||
|
# input_hash("abc") != input_hash("cba")
|
||||||
|
attention_metadata = subtuple(
|
||||||
|
metadata,
|
||||||
|
"TrimmedAttentionMetadata",
|
||||||
|
[
|
||||||
|
"block_list",
|
||||||
|
"block_mapping",
|
||||||
|
"block_usage",
|
||||||
|
"block_scales",
|
||||||
|
"block_groups",
|
||||||
|
"attn_bias",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return attention_metadata
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Seqlen:
|
||||||
|
input_lengths: torch.Tensor
|
||||||
|
cache_lengths: torch.Tensor
|
||||||
|
cu_seqlen_q: Optional[torch.Tensor]
|
||||||
|
cu_seqlen_k: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_lengths,
|
||||||
|
cache_lengths,
|
||||||
|
cu_seqlen_q=None,
|
||||||
|
):
|
||||||
|
self.input_lengths = input_lengths
|
||||||
|
self.cache_lengths = cache_lengths
|
||||||
|
device = self.input_lengths.device
|
||||||
|
shape = self.input_lengths.shape
|
||||||
|
if cu_seqlen_q is None:
|
||||||
|
cu_seqlen_q = torch.arange(
|
||||||
|
shape[0] + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||||
|
# Although FA2 might not want the clamping
|
||||||
|
# cu_seqlen_k[0] = 0
|
||||||
|
total = self.input_lengths + self.cache_lengths
|
||||||
|
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
||||||
|
|
||||||
|
self.cu_seqlen_q = cu_seqlen_q
|
||||||
|
self.cu_seqlen_k = cu_seqlen_k
|
||||||
|
|
||||||
|
def clamp(self, max):
|
||||||
|
# Flash decoding doesn't need to clamp
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def trim_seqlen_metadata(metadata: Seqlen) -> object:
|
||||||
|
# NOTE(kzawora): To anyone working on this in the future:
|
||||||
|
# Trimming metadata is required when using HPUGraphs.
|
||||||
|
# Attention metadata is going to be hashed by PT bridge, and
|
||||||
|
# appropriate HPUGraphs will be matched based on all inputs' hash.
|
||||||
|
|
||||||
|
# Before you put more keys in here, make sure you know their
|
||||||
|
# value type and make sure you know how it's going to be hashed.
|
||||||
|
# You can find that information in input_hash function
|
||||||
|
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
|
||||||
|
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
|
||||||
|
|
||||||
|
# If you use primitive types here - they will get hashed based
|
||||||
|
# on their value. You *will* get lots of excessive graph captures
|
||||||
|
# (and an OOM eventually) if you decide to put something like
|
||||||
|
# seq_len int here.
|
||||||
|
# If you absolutely need a scalar, put it in a tensor. Tensors
|
||||||
|
# get hashed using their metadata, not their values:
|
||||||
|
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
|
||||||
|
# input_hash(123) != input_hash(321)
|
||||||
|
# input_hash("abc") != input_hash("cba")
|
||||||
|
attention_metadata = subtuple(
|
||||||
|
metadata,
|
||||||
|
"TrimmedSeqlen",
|
||||||
|
[
|
||||||
|
"input_lengths",
|
||||||
|
"cache_lengths",
|
||||||
|
"cu_seqlen_q",
|
||||||
|
"cu_seqlen_k",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return attention_metadata
|
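# A small illustration of what `subtuple` / `trim_attn_metadata` above produce: a cached
# namedtuple that keeps only the listed tensor fields, so HPU graph hashing sees tensor
# metadata rather than the original Python object. Values below are illustrative and the
# names refer to the definitions in the file above.
import torch

_meta = HPUPagedAttentionMetadata(
    block_list=torch.tensor([0, 1]),
    block_mapping=None,
    block_usage=None,
    block_scales=None,
    block_groups=None,
    attn_bias=None,
)
_trimmed = trim_attn_metadata(_meta)
print(type(_trimmed).__name__)  # "TrimmedAttentionMetadata"
print(_trimmed.block_list)  # tensor([0, 1])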
@ -0,0 +1,95 @@
|
|||||||
|
import torch
|
||||||
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
|
from typing import Optional
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||||
|
from vllm_hpu_extension import ops
|
||||||
|
from vllm_hpu_extension.utils import Matmul
|
||||||
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||||
|
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||||
|
import os
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = False
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_from_cache(cache, blocks):
|
||||||
|
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
|
||||||
|
return cache[: blocks.size(0)]
|
||||||
|
else:
|
||||||
|
return cache.index_select(0, blocks)
|
||||||
|
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
*,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
kv_scales: KVScales,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
softmax_scale: float,
|
||||||
|
window_size_left: int = -1,
|
||||||
|
causal: bool = True,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
):
|
||||||
|
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||||
|
bs = seqlen.input_lengths.shape[0]
|
||||||
|
_, head_num, head_size = query.shape
|
||||||
|
_, kv_head_num, head_size = key.shape
|
||||||
|
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
|
||||||
|
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
|
||||||
|
value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
|
||||||
|
attn_output = fsdpa_op(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_mask=None,
|
||||||
|
dropout_p=0.0,
|
||||||
|
is_causal=causal,
|
||||||
|
scale=softmax_scale,
|
||||||
|
softmax_mode="None",
|
||||||
|
recompute_mode=None,
|
||||||
|
valid_sequence_lengths=seqlen.input_lengths,
|
||||||
|
padding_side="left",
|
||||||
|
)
|
||||||
|
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
kv_head_mapping: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
*,
|
||||||
|
kv_scales: KVScales,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
hpu_attention_meta: HPUPagedAttentionMetadata,
|
||||||
|
):
|
||||||
|
batch_size, head_num, head_size = query.shape
|
||||||
|
output = ops.flat_pa(
|
||||||
|
query=query.view(batch_size, 1, head_num * head_size),
|
||||||
|
key_cache=kv_cache.key,
|
||||||
|
value_cache=kv_cache.value,
|
||||||
|
block_list=hpu_attention_meta.block_list,
|
||||||
|
block_mapping=hpu_attention_meta.block_mapping,
|
||||||
|
block_bias=hpu_attention_meta.attn_bias,
|
||||||
|
block_scales=hpu_attention_meta.block_scales,
|
||||||
|
block_groups=hpu_attention_meta.block_groups,
|
||||||
|
scale=softmax_scale,
|
||||||
|
matmul_qk_op=Matmul(),
|
||||||
|
matmul_av_op=Matmul(),
|
||||||
|
batch2block_matmul_op=Matmul(),
|
||||||
|
block2batch_matmul_op=Matmul(),
|
||||||
|
keys_fetch_func=fetch_from_cache,
|
||||||
|
values_fetch_func=fetch_from_cache,
|
||||||
|
)
|
||||||
|
# Reshape the output tensor.
|
||||||
|
return output.view(batch_size, head_num, head_size)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SUPPORTS_WINDOWING",
|
||||||
|
"attention",
|
||||||
|
"paged_attention",
|
||||||
|
]
|
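# A quick CPU-only illustration of `fetch_from_cache` above: with contiguous paged
# attention enabled it slices the first len(blocks) entries, otherwise it gathers rows by
# block index. Tensors and the environment toggle are illustrative only.
import os

import torch

_cache = torch.arange(10).reshape(5, 2)
_blocks = torch.tensor([3, 1])

os.environ["VLLM_CONTIGUOUS_PA"] = "true"
print(fetch_from_cache(_cache, _blocks))  # rows 0 and 1 (first len(blocks) entries)

os.environ["VLLM_CONTIGUOUS_PA"] = "false"
print(fetch_from_cache(_cache, _blocks))  # rows 3 and 1 (index_select by block id)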
@ -0,0 +1,139 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from text_generation_server.models.globals import BLOCK_SIZE
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
from vllm_hpu_extension import cache_ops
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KVScales:
|
||||||
|
"""
|
||||||
|
Key-value scales for FP8 KV cache.
|
||||||
|
|
||||||
|
This data class stores key and value scales both as a GPU tensor and
|
||||||
|
as a GPU float. This inconvenience is necessary because some functions
|
||||||
|
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
|
||||||
|
(e.g. flashinfer) take scales as a CPU scalar.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key_scale: torch.Tensor
|
||||||
|
value_scale: torch.Tensor
|
||||||
|
key_scale_cpu: float = field(init=False)
|
||||||
|
value_scale_cpu: float = field(init=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
|
||||||
|
raise ValueError("Key and value scales must be scalar tensors.")
|
||||||
|
|
||||||
|
self.key_scale_cpu = self.key_scale.item()
|
||||||
|
self.value_scale_cpu = self.value_scale.item()
|
||||||
|
|
||||||
|
|
||||||
|
class KVCache:
|
||||||
|
"""
|
||||||
|
Key-value cache for attention layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kv_cache: Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_blocks: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
"""Construct the key-value cache for a layer."""
|
||||||
|
## TODO FP8 kv cache support
|
||||||
|
|
||||||
|
self.kv_cache = (
|
||||||
|
torch.zeros(
|
||||||
|
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""Get the data type of the cache."""
|
||||||
|
return self.kv_cache[0].dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key(self):
|
||||||
|
"""Get the key cache."""
|
||||||
|
|
||||||
|
return self.kv_cache[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
"""Get the value cache."""
|
||||||
|
|
||||||
|
return self.kv_cache[1]
|
||||||
|
|
||||||
|
def store(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
kv_scales: KVScales,
|
||||||
|
):
|
||||||
|
"""Store the key and value at the given slots."""
|
||||||
|
## TODO FP8 kv cache support
|
||||||
|
|
||||||
|
key_cache = self.kv_cache[0]
|
||||||
|
value_cache = self.kv_cache[1]
|
||||||
|
|
||||||
|
paged_reshape_and_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
slots,
|
||||||
|
kv_scales.key_scale_cpu,
|
||||||
|
kv_scales.value_scale_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
|
):
|
||||||
|
block_idx = slots // BLOCK_SIZE
|
||||||
|
block_offset = slots % BLOCK_SIZE
|
||||||
|
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
|
||||||
|
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
|
||||||
|
"""Load KV cache scales."""
|
||||||
|
|
||||||
|
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
|
||||||
|
value_scale = key_scale
|
||||||
|
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
|
||||||
|
f"{prefix}.v_scale"
|
||||||
|
):
|
||||||
|
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
|
||||||
|
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
|
||||||
|
elif weights.has_tensor(f"{prefix}.kv_scale"):
|
||||||
|
# Fall back to older more coarse-grained scale when available.
|
||||||
|
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
|
||||||
|
value_scale = key_scale
|
||||||
|
|
||||||
|
return KVScales(key_scale=key_scale, value_scale=value_scale)
|
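# Sketch of the slot -> (block, offset) mapping used by `paged_reshape_and_cache` above,
# with BLOCK_SIZE fixed to 16 purely for illustration.
import torch

_BLOCK_SIZE = 16
_slots = torch.tensor([0, 1, 17, 35])
print(_slots // _BLOCK_SIZE)  # tensor([0, 0, 1, 2]) -> block index
print(_slots % _BLOCK_SIZE)  # tensor([0, 1, 1, 3]) -> offset within the block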
@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
|
||||||
|
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def pack(imatrix: torch.Tensor, direction: str = "column"):
|
||||||
|
"""
|
||||||
|
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
direction (str): direction of packing, either "column" or "row"
|
||||||
|
Returns:
|
||||||
|
qmatrix (torch.Tensor): packed matrix of integers
|
||||||
|
"""
|
||||||
|
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
|
||||||
|
|
||||||
|
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
|
||||||
|
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
|
||||||
|
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
|
||||||
|
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
|
||||||
|
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
|
||||||
|
|
||||||
|
qmatrix = qmatrix.to(torch.int32)
|
||||||
|
|
||||||
|
return qmatrix
|
||||||
|
|
||||||
|
|
||||||
|
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
|
||||||
|
"""
|
||||||
|
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
qmatrix (torch.Tensor): matrix of packed integers
|
||||||
|
direction (str): direction of unpacking, either "column" or "row"
|
||||||
|
Returns:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
"""
|
||||||
|
shifts = torch.arange(0, 32, 4, device=qmatrix.device)
|
||||||
|
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = torch.bitwise_right_shift(
|
||||||
|
qmatrix[:, :, None], shifts[None, None, :]
|
||||||
|
).view(qmatrix.shape[0], -1)
|
||||||
|
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = torch.bitwise_right_shift(
|
||||||
|
qmatrix[:, None, :], shifts[None, :, None]
|
||||||
|
).view(-1, qmatrix.shape[-1])
|
||||||
|
|
||||||
|
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
|
||||||
|
|
||||||
|
return imatrix
|
||||||
|
|
||||||
|
|
||||||
|
def apply_order(
|
||||||
|
imatrix: torch.Tensor,
|
||||||
|
direction: str = "column",
|
||||||
|
order: List[int] = AWQ_PACK_ORDER,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Applies the order to a 4-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
direction (str): direction of applying order, either "column" or "row"
|
||||||
|
order (List[int]): order to apply, default is AWQ_PACK_ORDER
|
||||||
|
Returns:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
"""
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
|
||||||
|
|
||||||
|
return imatrix
|
||||||
|
|
||||||
|
|
||||||
|
def fast_awq_to_gptq(qweight, qzeros):
|
||||||
|
# awq uses column packing for both weights and zeros
|
||||||
|
izeros = unpack(qzeros, direction="column")
|
||||||
|
iweights = unpack(qweight, direction="column")
|
||||||
|
|
||||||
|
# Reverse the order of the iweight and izeros tensors
|
||||||
|
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
|
||||||
|
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
|
||||||
|
# Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
|
||||||
|
izeros = izeros - 1
|
||||||
|
# exllama uses row packing for weights and column packing for zeros
|
||||||
|
qzeros = pack(izeros, direction="column")
|
||||||
|
qweight = pack(iweights, direction="row")
|
||||||
|
|
||||||
|
return qweight, qzeros
|
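# A small round-trip check for the packing helpers above: `pack` folds eight 4-bit values
# into each int32, and `unpack` recovers them. Shapes and values are illustrative.
import torch

_ints = torch.randint(0, 16, (4, 64), dtype=torch.int32)
_packed = pack(_ints, direction="column")  # shape (4, 8), eight nibbles per int32
_unpacked = unpack(_packed, direction="column")  # back to shape (4, 64)
print(torch.equal(_unpacked.to(torch.int32), _ints))  # expected: True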
@@ -0,0 +1,3 @@
from .hpu import WQLinear

__all__ = ["WQLinear"]
@ -0,0 +1,134 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
try:
|
||||||
|
import habana_frameworks.torch.hpu # noqa: F401
|
||||||
|
|
||||||
|
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
|
||||||
|
except Exception as e:
|
||||||
|
hpu_import_exception = e
|
||||||
|
|
||||||
|
def error_raiser_hpu(*args, **kwargs):
|
||||||
|
raise ValueError(
|
||||||
|
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
|
||||||
|
)
|
||||||
|
|
||||||
|
convert_from_uint4 = error_raiser_hpu
|
||||||
|
|
||||||
|
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
|
||||||
|
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||||
|
|
||||||
|
# unpacking columnwise
|
||||||
|
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
|
||||||
|
torch.int8 # smallest dtype available
|
||||||
|
)
|
||||||
|
iweights = iweights.view(iweights.shape[0], -1)
|
||||||
|
|
||||||
|
# unpacking columnwise
|
||||||
|
if qzeros is not None:
|
||||||
|
izeros = torch.bitwise_right_shift(
|
||||||
|
qzeros[:, :, None], shifts[None, None, :]
|
||||||
|
).to(
|
||||||
|
torch.int8 # smallest dtype available
|
||||||
|
)
|
||||||
|
izeros = izeros.view(izeros.shape[0], -1)
|
||||||
|
else:
|
||||||
|
izeros = qzeros
|
||||||
|
|
||||||
|
return iweights, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
|
||||||
|
reverse_order_tensor = torch.arange(
|
||||||
|
iweights.shape[-1],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=izeros.device,
|
||||||
|
)
|
||||||
|
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||||
|
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
|
||||||
|
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||||
|
|
||||||
|
if izeros is not None:
|
||||||
|
izeros = izeros[:, reverse_order_tensor]
|
||||||
|
iweights = iweights[:, reverse_order_tensor]
|
||||||
|
|
||||||
|
return iweights, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_weight_and_zeros(qweight, qzeros, bits):
|
||||||
|
# Unpack the qweight and qzeros tensors
|
||||||
|
iweight, izeros = unpack_awq(qweight, qzeros, bits)
|
||||||
|
# Reverse the order of the iweight and izeros tensors
|
||||||
|
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
|
||||||
|
|
||||||
|
# overflow checks
|
||||||
|
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||||
|
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||||
|
|
||||||
|
return iweight, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def pack_tensor(input, bits=4):
|
||||||
|
normal = input.to(torch.int32)
|
||||||
|
q = torch.zeros(
|
||||||
|
(normal.shape[0], normal.shape[1] // 32 * bits),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=input.device,
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
col = 0
|
||||||
|
while col < q.shape[1]:
|
||||||
|
for j in range(i, i + (32 // bits)):
|
||||||
|
q[:, col] |= normal[:, j] << (bits * (j - i))
|
||||||
|
i += 32 // bits
|
||||||
|
col += 1
|
||||||
|
q = q.to(torch.int32)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
class WQLinear(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if w_bit not in [4]:
|
||||||
|
raise NotImplementedError("Only 4-bit are supported for now.")
|
||||||
|
|
||||||
|
self.in_features = qweight.shape[0]
|
||||||
|
self.out_features = qweight.shape[1] * 32 // w_bit
|
||||||
|
|
||||||
|
self.w_bit = w_bit
|
||||||
|
self.group_size = group_size if group_size != -1 else self.in_features
|
||||||
|
# quick sanity check (make sure the shapes are aligned)
|
||||||
|
assert self.in_features % self.group_size == 0
|
||||||
|
assert self.out_features % (32 // self.w_bit) == 0
|
||||||
|
|
||||||
|
self.qweight = qweight
|
||||||
|
self.qzeros = qzeros
|
||||||
|
self.scales = scales
|
||||||
|
self.bias = bias
|
||||||
|
self._preprocessing()
|
||||||
|
|
||||||
|
def _preprocessing(self):
|
||||||
|
device = self.qweight.device
|
||||||
|
weight, zeros = unpack_weight_and_zeros(
|
||||||
|
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
|
||||||
|
)
|
||||||
|
self.qweight = pack_tensor(weight).to(device)
|
||||||
|
self.qzeros = pack_tensor(zeros).to(device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, x):
|
||||||
|
out_shape = x.shape[:-1] + (self.out_features,)
|
||||||
|
x = x.reshape(-1, x.shape[-1])
|
||||||
|
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
|
||||||
|
outputs = torch.matmul(x, weights)
|
||||||
|
|
||||||
|
outputs = outputs + self.bias if self.bias is not None else outputs
|
||||||
|
outputs = outputs.reshape(out_shape)
|
||||||
|
return outputs
|
backends/gaudi/server/text_generation_server/layers/bnb.py (new file, 124 lines)
@@ -0,0 +1,124 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import torch
|
||||||
|
from bitsandbytes.nn import Int8Params, Params4bit
|
||||||
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BNBWeight(UnquantizedWeight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
|
||||||
|
|
||||||
|
|
||||||
|
class Linear8bitLt(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
has_fp16_weights=True,
|
||||||
|
memory_efficient_backward=False,
|
||||||
|
threshold=0.0,
|
||||||
|
index=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert (
|
||||||
|
not memory_efficient_backward
|
||||||
|
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||||
|
self.state = bnb.MatmulLtState()
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
# Necessary for stacked layers
|
||||||
|
self.state.threshold = threshold
|
||||||
|
self.state.has_fp16_weights = has_fp16_weights
|
||||||
|
self.state.memory_efficient_backward = memory_efficient_backward
|
||||||
|
if threshold > 0.0 and not has_fp16_weights:
|
||||||
|
self.state.use_pool = True
|
||||||
|
|
||||||
|
self.weight = Int8Params(
|
||||||
|
weight.data,
|
||||||
|
has_fp16_weights=has_fp16_weights,
|
||||||
|
requires_grad=has_fp16_weights,
|
||||||
|
)
|
||||||
|
self.weight.cuda(weight.device)
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
|
def init_8bit_state(self):
|
||||||
|
self.state.CB = self.weight.CB
|
||||||
|
self.state.SCB = self.weight.SCB
|
||||||
|
self.weight.CB = None
|
||||||
|
self.weight.SCB = None
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
self.state.is_training = self.training
|
||||||
|
if self.weight.CB is not None:
|
||||||
|
self.init_8bit_state()
|
||||||
|
|
||||||
|
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||||
|
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||||
|
self.bias.data = self.bias.data.to(x.dtype)
|
||||||
|
|
||||||
|
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||||
|
|
||||||
|
if not self.state.has_fp16_weights:
|
||||||
|
if self.state.CB is not None and self.state.CxB is not None:
|
||||||
|
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||||
|
# we no longer need the row-major weight
|
||||||
|
del self.state.CB
|
||||||
|
self.weight.data = self.state.CxB
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BNBFP4Weight(UnquantizedWeight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
return Linear4bit(self.weight, bias, quant_type="fp4")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BNBNF4Weight(UnquantizedWeight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
return Linear4bit(self.weight, bias, quant_type="nf4")
|
||||||
|
|
||||||
|
|
||||||
|
class Linear4bit(torch.nn.Module):
|
||||||
|
def __init__(self, weight, bias, quant_type):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = Params4bit(
|
||||||
|
weight.data,
|
||||||
|
requires_grad=False,
|
||||||
|
compress_statistics=True,
|
||||||
|
quant_type=quant_type,
|
||||||
|
)
|
||||||
|
self.compute_dtype = None
|
||||||
|
self.weight.cuda(weight.device)
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||||
|
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||||
|
self.bias.data = self.bias.data.to(x.dtype)
|
||||||
|
|
||||||
|
if getattr(self.weight, "quant_state", None) is None:
|
||||||
|
print(
|
||||||
|
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
|
||||||
|
)
|
||||||
|
inp_dtype = x.dtype
|
||||||
|
if self.compute_dtype is not None:
|
||||||
|
x = x.to(self.compute_dtype)
|
||||||
|
|
||||||
|
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||||
|
out = bnb.matmul_4bit(
|
||||||
|
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
|
||||||
|
)
|
||||||
|
|
||||||
|
out = out.to(inp_dtype)
|
||||||
|
|
||||||
|
return out
|
backends/gaudi/server/text_generation_server/layers/conv.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from accelerate import init_empty_weights
import torch


@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    conv2d.bias = torch.nn.Parameter(bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    conv2d.bias = None
    return conv2d


torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
backends/gaudi/server/text_generation_server/layers/exl2.py (new file, 78 lines)
@@ -0,0 +1,78 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Exl2Weight(Weight):
|
||||||
|
"""
|
||||||
|
Exllama2 exl2 quantized weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
q_weight: torch.Tensor
|
||||||
|
q_scale: torch.Tensor
|
||||||
|
q_invperm: torch.Tensor
|
||||||
|
q_scale_max: torch.Tensor
|
||||||
|
q_groups: torch.Tensor
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.q_scale_max /= 256
|
||||||
|
self.q_invperm = self.q_invperm.short()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.q_weight.device
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
||||||
|
|
||||||
|
return ExllamaQuantLinear(self, bias)
|
||||||
|
|
||||||
|
|
||||||
|
class Exl2WeightsLoader(WeightsLoader):
|
||||||
|
"""Loader for exl2-quantized weights."""
|
||||||
|
|
||||||
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
|
"""
|
||||||
|
Get weights at the given prefix and apply without tensor parallelism.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
|
||||||
|
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||||
|
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||||
|
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||||
|
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||||
|
|
||||||
|
return Exl2Weight(
|
||||||
|
q_weight=q_weight,
|
||||||
|
q_scale=q_scale,
|
||||||
|
q_invperm=q_invperm,
|
||||||
|
q_scale_max=q_scale_max,
|
||||||
|
q_groups=q_groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
raise RuntimeError("Column-packed weights are not supported for exl")
|
||||||
|
|
||||||
|
def get_weights_col(self, weights: Weights, prefix: str):
|
||||||
|
# Sharding is not yet supported, so we return the weights as-is.
|
||||||
|
return self.get_weights(weights, prefix)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
# Sharding is not yet supported, so we return the weights as-is.
|
||||||
|
return self.get_weights(weights, prefix)
|
backends/gaudi/server/text_generation_server/layers/fp8.py (new file, 458 lines)
@@ -0,0 +1,458 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple, Type, Union, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from text_generation_server.utils.weights import (
|
||||||
|
Weight,
|
||||||
|
WeightsLoader,
|
||||||
|
UnquantizedWeight,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm_hpu_extension.ops import scaled_fp8_quant
|
||||||
|
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
|
||||||
|
import habana_frameworks.torch.utils.experimental as htexp
|
||||||
|
|
||||||
|
w8a8_block_fp8_matmul = None
|
||||||
|
per_token_group_quant_fp8 = None
|
||||||
|
quant_dtype: torch.dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
|
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
|
||||||
|
"""
|
||||||
|
Return an FP8 linear `Module` that is compatible with the current system.
|
||||||
|
"""
|
||||||
|
# On other systems let Torch decide if the hardware supports FP8.
|
||||||
|
return Fp8Linear
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_e4m3fn_to_native_float8(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
return weight, weight_scale, input_scale
|
||||||
|
|
||||||
|
|
||||||
|
def per_tensor_dequantize(
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
inv_scale: Union[float, torch.Tensor],
|
||||||
|
dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
device = tensor.device
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
|
||||||
|
# dequant on cpu to avoid nan on gaudi2
|
||||||
|
tensor = tensor.to("cpu")
|
||||||
|
|
||||||
|
fake_qweight = tensor.to(dtype).to(device)
|
||||||
|
dq_weight = fake_qweight * inv_scale
|
||||||
|
return dq_weight
|
||||||
|
|
||||||
|
|
||||||
|
def requantize_with_max_scale(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
logical_widths: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Max scale to be used for requantization.
|
||||||
|
max_w_scale = weight_scale.max()
|
||||||
|
|
||||||
|
if is_hpu_gaudi2():
|
||||||
|
max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
for idx, logical_width in enumerate(logical_widths):
|
||||||
|
end = start + logical_width
|
||||||
|
weight_dq = per_tensor_dequantize(
|
||||||
|
weight[start:end, :], weight_scale[idx], dtype
|
||||||
|
)
|
||||||
|
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
|
||||||
|
weight_dq, max_w_scale
|
||||||
|
)
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return weight, max_w_scale_normalized
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_quantize(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[torch.Tensor] = None,
|
||||||
|
qdtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
|
scalar: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This function returns a reciprocal of the scale, so that a tensor can be unscaled
|
||||||
|
by multiplying it with the returned scale. If a scale is given through the `scale`
|
||||||
|
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
|
||||||
|
be used without modification).
|
||||||
|
"""
|
||||||
|
shape = weight.shape
|
||||||
|
qweight, scale = scaled_fp8_quant(
|
||||||
|
weight.reshape(-1, shape[-1]),
|
||||||
|
scale=scale,
|
||||||
|
scale_ub=scale_upper_bound,
|
||||||
|
# TODO: don't do this when we have to use the Torch kernel.
|
||||||
|
use_per_token_if_dynamic=not scalar,
|
||||||
|
)
|
||||||
|
|
||||||
|
return qweight.reshape(shape), scale
|
||||||
|
|
||||||
|
|
||||||
|
class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
|
"""Weight loader that loads FP8 and unquantized Torch tensors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation_scale_ub: Optional[float],
|
||||||
|
to_fp8: bool,
|
||||||
|
weight_block_size: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
|
self.activation_scale_ub = activation_scale_ub
|
||||||
|
self.to_fp8 = to_fp8
|
||||||
|
self.weight_block_size = weight_block_size
|
||||||
|
|
||||||
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
|
w = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = (
|
||||||
|
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||||
|
.reshape(-1)
|
||||||
|
.max()
|
||||||
|
)
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
w = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
if scale.numel() > 1:
|
||||||
|
scale = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.weight_scale",
|
||||||
|
dim=0,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
to_dtype=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = weights.get_tensor(
|
||||||
|
f"{prefix}.input_scale", to_dtype=False
|
||||||
|
)
|
||||||
|
if input_scale.numel() > 1:
|
||||||
|
input_scale = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.input_scale",
|
||||||
|
dim=0,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
to_dtype=False,
|
||||||
|
)
|
||||||
|
input_scale = input_scale.reshape(-1).max()
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||||
|
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
|
||||||
|
w = [
|
||||||
|
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
|
||||||
|
]
|
||||||
|
shapes = [x.shape for x in w]
|
||||||
|
|
||||||
|
# Concat then send to the device
|
||||||
|
w = torch.cat(w, dim=dim).to(weights.device)
|
||||||
|
|
||||||
|
# FP8 branch
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
scale = [
|
||||||
|
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
|
||||||
|
for p in prefixes
|
||||||
|
]
|
||||||
|
scale = torch.cat(scale, dim=dim)
|
||||||
|
scale = scale.to(weights.device)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = [
|
||||||
|
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||||
|
for p, shape in zip(prefixes, shapes)
|
||||||
|
]
|
||||||
|
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||||
|
|
||||||
|
input_scale = [
|
||||||
|
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||||
|
for p, shape in zip(prefixes, shapes)
|
||||||
|
if weights.has_tensor(f"{p}.input_scale")
|
||||||
|
]
|
||||||
|
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||||
|
input_scale = (
|
||||||
|
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||||
|
if len(input_scale) != 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logical_widths = [x[0] for x in shapes]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.to(weights.device), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||||
|
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
|
# FP8 branch
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
# XXX: The tensor is named weight_scale_inv, but it appears to hold the scale itself.
|
||||||
|
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = (
|
||||||
|
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||||
|
.reshape(-1)
|
||||||
|
.max()
|
||||||
|
)
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Fp8Weight(Weight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
dtype: torch.dtype
|
||||||
|
weight_scale: Optional[torch.Tensor] = None
|
||||||
|
input_scale: Optional[torch.Tensor] = None
|
||||||
|
activation_scale_ub: Optional[float] = None
|
||||||
|
force_w8a16: bool = False
|
||||||
|
weight_block_size: Optional[List[int]] = None
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
if self.weight_scale is None:
|
||||||
|
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
|
||||||
|
self.weight, bias, self.dtype
|
||||||
|
)
|
||||||
|
# This is not checked by the fbgemm kernels, but they require contiguous
|
||||||
|
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
||||||
|
self.weight_scale = self.weight_scale.contiguous()
|
||||||
|
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
|
||||||
|
weight=self.weight,
|
||||||
|
scale=self.weight_scale,
|
||||||
|
dtype=self.dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=self.input_scale,
|
||||||
|
scale_upper_bound=self.activation_scale_ub,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8Linear(torch.nn.Module):
|
||||||
|
_device_identity_cache = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qweight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[float] = None,
|
||||||
|
weight_block_size: Optional[List[int]] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
self.qweight = qweight
|
||||||
|
self.scale = scale.float()
|
||||||
|
self.input_scale = input_scale.float() if input_scale is not None else None
|
||||||
|
self.weight_block_size = weight_block_size
|
||||||
|
self.scale_upper_bound = scale_upper_bound
|
||||||
|
|
||||||
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unquant(cls, weight, bias, dtype):
|
||||||
|
qweight, scale = fp8_quantize(weight, scalar=True)
|
||||||
|
return cls(
|
||||||
|
qweight=qweight,
|
||||||
|
scale=scale,
|
||||||
|
dtype=dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=None,
|
||||||
|
scale_upper_bound=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_fp8(
|
||||||
|
cls,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "Fp8Linear":
|
||||||
|
input_scale = kwargs.get("input_scale", None)
|
||||||
|
scale_upper_bound = kwargs.get("scale_upper_bound", None)
|
||||||
|
weight_block_size = kwargs.get("weight_block_size", None)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
qweight=weight,
|
||||||
|
scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
scale_upper_bound=scale_upper_bound,
|
||||||
|
bias=bias,
|
||||||
|
dtype=dtype,
|
||||||
|
weight_block_size=weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_shared_device_identity(cls, device):
|
||||||
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
|
if device not in cls._device_identity_cache:
|
||||||
|
cls._device_identity_cache[device] = torch.ones(1, device=device)
|
||||||
|
return cls._device_identity_cache[device]
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
# https://arxiv.org/pdf/2412.19437
|
||||||
|
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
|
||||||
|
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
|
||||||
|
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
|
||||||
|
# channels).
|
||||||
|
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
|
||||||
|
output = w8a8_block_fp8_matmul(
|
||||||
|
qinput,
|
||||||
|
self.qweight,
|
||||||
|
scale,
|
||||||
|
self.scale,
|
||||||
|
self.weight_block_size,
|
||||||
|
output_dtype=input.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
output = output + self.bias
|
||||||
|
return output.to(dtype=input.dtype)
|
||||||
|
|
||||||
|
qinput, scale = fp8_quantize(
|
||||||
|
input,
|
||||||
|
self.input_scale,
|
||||||
|
scale_upper_bound=self.scale_upper_bound,
|
||||||
|
scalar=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
qinput,
|
||||||
|
self.qweight.t(),
|
||||||
|
out_dtype=self.dtype,
|
||||||
|
scale_a=scale,
|
||||||
|
scale_b=self.scale,
|
||||||
|
bias=self.bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple) and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
|
||||||
|
scale = weights.get_tensor(prefix, to_dtype=False)
|
||||||
|
|
||||||
|
if scale.numel() > 1:
|
||||||
|
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
|
||||||
|
return scale.reshape(-1)
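# Hedged sketch of the core `torch._scaled_mm` call that Fp8Linear.forward performs on
# the per-tensor path. Assumptions: a GPU with FP8 support and a PyTorch build whose
# `_scaled_mm` accepts these keyword arguments (as the forward above does); the shapes
# and values are made up.
def _scaled_mm_example():
    x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")

    def quant(t):
        # Same per-tensor recipe as fp8_quantize(..., scalar=True).
        scale = t.abs().amax().float() / 448.0
        q = (t / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
        return q, scale.reshape(1)

    qx, sx = quant(x)
    qw, sw = quant(w)
    # qw is (out, in) and contiguous, so qw.t() is the column-major operand.
    out = torch._scaled_mm(qx, qw.t(), out_dtype=torch.bfloat16, scale_a=sx, scale_b=sw)
    if isinstance(out, tuple):  # older PyTorch releases returned (out, amax)
        out = out[0]
    return out  # shape (4, 256)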
|
New file (357 lines):
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
|
|
||||||
|
from .hpu import QuantLinear
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPTQWeight(Weight):
|
||||||
|
qweight: torch.Tensor
|
||||||
|
qzeros: torch.Tensor
|
||||||
|
scales: torch.Tensor
|
||||||
|
g_idx: Optional[torch.Tensor]
|
||||||
|
bits: int
|
||||||
|
groupsize: int
|
||||||
|
use_awq_kernel: bool
|
||||||
|
use_exllama: bool
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.scales.dtype == torch.float:
|
||||||
|
self.scales = self.scales.half()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.qweight.device
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
if self.use_awq_kernel:
|
||||||
|
try:
|
||||||
|
from text_generation_server.layers.awq.quantize import WQLinear
|
||||||
|
|
||||||
|
return WQLinear(
|
||||||
|
w_bit=self.bits,
|
||||||
|
group_size=self.groupsize,
|
||||||
|
qweight=self.qweight,
|
||||||
|
qzeros=self.qzeros,
|
||||||
|
scales=self.scales,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return QuantLinear(
|
||||||
|
self.qweight,
|
||||||
|
self.qzeros,
|
||||||
|
self.scales,
|
||||||
|
self.g_idx,
|
||||||
|
bias,
|
||||||
|
self.bits,
|
||||||
|
self.groupsize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQWeightsLoader(WeightsLoader):
|
||||||
|
"""
|
||||||
|
Loader for GPTQ- and AWQ-quantized weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bits: int,
|
||||||
|
desc_act: bool,
|
||||||
|
groupsize: int,
|
||||||
|
quant_method: str,
|
||||||
|
quantize: str,
|
||||||
|
sym: bool,
|
||||||
|
):
|
||||||
|
self.bits = bits
|
||||||
|
self.desc_act = desc_act
|
||||||
|
self.groupsize = groupsize
|
||||||
|
self.quant_method = quant_method
|
||||||
|
self.quantize = quantize
|
||||||
|
self.sym = sym
|
||||||
|
|
||||||
|
def get_weights(self, weights: Weights, prefix: str):
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
use_exllama = True
|
||||||
|
if self.bits != 4:
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
if self.desc_act:
|
||||||
|
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
|
if use_exllama and g_idx is not None:
|
||||||
|
g_idx = g_idx - g_idx[0]
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
qweight = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
scales = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
scales = scales.to(dtype=weights.dtype)
|
||||||
|
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
qzeros = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_awq_kernel=self.quantize == "awq",
|
||||||
|
use_exllama=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
try:
|
||||||
|
qweight = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
scales = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
qzeros = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_awq_kernel=self.quantize == "awq",
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
use_exllama = True
|
||||||
|
desc_act = self.desc_act
|
||||||
|
if self.bits != 4:
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
if self.desc_act:
|
||||||
|
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
if weights.process_group.size() > 1:
|
||||||
|
if g_idx is not None:
|
||||||
|
if (
|
||||||
|
not torch.equal(
|
||||||
|
# Remove g_idx[0] to adapt the check with TP>1.
|
||||||
|
(g_idx - g_idx[0]).cpu(),
|
||||||
|
torch.tensor(
|
||||||
|
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and not (g_idx == 0).all()
|
||||||
|
):
|
||||||
|
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split onto several GPUs
|
||||||
|
use_exllama = False
|
||||||
|
desc_act = True
|
||||||
|
|
||||||
|
from text_generation_server.layers.gptq import (
|
||||||
|
GPTQWeight,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not desc_act and self.groupsize != -1:
|
||||||
|
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
if g_idx is not None:
|
||||||
|
# qzeros, scales sharded, and g_idx must be adjusted accordingly
|
||||||
|
g_idx = g_idx - g_idx[0]
|
||||||
|
else:
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_awq_kernel=self.quantize == "awq",
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_gptq_params(self, weights: Weights):
|
||||||
|
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
|
||||||
|
self.bits = weights.get_tensor("gptq_bits").item()
|
||||||
|
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||||
|
self.desc_act = False
|
||||||
|
# `server quantize` used asymmetric quantization unconditionally
|
||||||
|
# before the `gptq_sym` setting tensor was added.
|
||||||
|
self.sym = (
|
||||||
|
weights.get_tensor("gptq_sym").item()
|
||||||
|
if weights.has_tensor("gptq_sym")
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
self.quant_method = "gptq"
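# Numeric illustration (sizes are hypothetical): the "trivial" g_idx that the loaders
# above rebuild after an AWQ -> GPTQ conversion simply maps input channel i to
# quantization group i // groupsize.
def _trivial_g_idx_example():
    bits, groupsize = 4, 128
    packed_rows = 64                          # qweight.shape[0]
    in_features = packed_rows * (32 // bits)  # 512 unpacked input channels
    g_idx = (torch.arange(in_features) // groupsize).to(dtype=torch.int32)
    return g_idx  # [0]*128 + [1]*128 + [2]*128 + [3]*128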
|
New file: backends/gaudi/server/text_generation_server/layers/gptq/hpu.py (186 lines)
|
|||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
|
||||||
|
except Exception as e:
|
||||||
|
hpu_import_exception = e
|
||||||
|
|
||||||
|
def error_raiser_hpu(*args, **kwargs):
|
||||||
|
raise ValueError(
|
||||||
|
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
|
||||||
|
)
|
||||||
|
|
||||||
|
convert_from_uint4 = error_raiser_hpu
|
||||||
|
|
||||||
|
|
||||||
|
def pack_tensor(input, bits=4):
|
||||||
|
normal = input.to(torch.int32)
|
||||||
|
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
|
||||||
|
i = 0
|
||||||
|
col = 0
|
||||||
|
while col < q.shape[1]:
|
||||||
|
for j in range(i, i + (32 // bits)):
|
||||||
|
q[:, col] |= normal[:, j] << (bits * (j - i))
|
||||||
|
i += 32 // bits
|
||||||
|
col += 1
|
||||||
|
q = q.to(torch.int32)
|
||||||
|
return q
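# Quick sanity check on the packing layout (a standalone sketch; the values are made up):
# eight 4-bit values land in the nibbles of one int32, least significant nibble first.
def _pack_tensor_example():
    values = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
    packed = pack_tensor(values, bits=4)
    assert hex(packed.item()) == "0x76543210"  # value j occupies nibble j
    return packed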
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinear(nn.Module):
|
||||||
|
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("qweight", qweight)
|
||||||
|
self.register_buffer("qzeros", qzeros)
|
||||||
|
self.register_buffer("scales", scales)
|
||||||
|
self.register_buffer("g_idx", g_idx)
|
||||||
|
if bias is not None:
|
||||||
|
self.register_buffer("bias", bias)
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
if bits not in [4]:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
self.bits = bits
|
||||||
|
self.maxq = 2**self.bits - 1
|
||||||
|
self.groupsize = groupsize
|
||||||
|
|
||||||
|
self.outfeatures = qweight.shape[1]
|
||||||
|
self.infeatures = qweight.shape[0] * 32 // bits
|
||||||
|
self.wf = torch.tensor(
|
||||||
|
list(range(0, 32, self.bits)), dtype=torch.int32
|
||||||
|
).unsqueeze(0)
|
||||||
|
self._preprocessing()
|
||||||
|
|
||||||
|
def unpack_zeros_from_cuda_old_format(self):
|
||||||
|
zeros = torch.bitwise_right_shift(
|
||||||
|
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
|
||||||
|
self.wf.unsqueeze(0),
|
||||||
|
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||||
|
|
||||||
|
zeros = zeros + 1
|
||||||
|
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
|
||||||
|
self.scales.dtype
|
||||||
|
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
|
||||||
|
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
|
||||||
|
return zeros
|
||||||
|
|
||||||
|
def unpack_weight_from_cuda_old_format(self):
|
||||||
|
weight = torch.bitwise_right_shift(
|
||||||
|
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
|
||||||
|
self.wf.unsqueeze(-1),
|
||||||
|
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||||
|
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
|
||||||
|
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def _preprocessing(self):
|
||||||
|
orig_device = self.qweight.device
|
||||||
|
self.qweight = self.qweight.cpu()
|
||||||
|
weight = self.unpack_weight_from_cuda_old_format()
|
||||||
|
new_qweight = pack_tensor(weight)
|
||||||
|
self.qweight = new_qweight.to(orig_device)
|
||||||
|
# TODO: Support group indexing and remove the check
|
||||||
|
columns = self.qweight.shape[0]
|
||||||
|
g_idx_trivial = [i // self.groupsize for i in range(columns)]
|
||||||
|
g_idx_trivial = torch.tensor(
|
||||||
|
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
self.g_idx, g_idx_trivial
|
||||||
|
), "Non-trivial tensor g_idx is not supported"
|
||||||
|
self.qzeros = self.qzeros.cpu()
|
||||||
|
zeros = self.unpack_zeros_from_cuda_old_format()
|
||||||
|
new_qzeros = pack_tensor(zeros)
|
||||||
|
self.qzeros = new_qzeros.to(orig_device)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||||
|
if bits not in [4]:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||||
|
qzeros = torch.zeros(
|
||||||
|
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
scales = torch.zeros(
|
||||||
|
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||||
|
)
|
||||||
|
g_idx = torch.tensor(
|
||||||
|
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||||
|
|
||||||
|
def pack(self, linear, scales, zeros, g_idx=None):
|
||||||
|
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||||
|
|
||||||
|
scales = scales.t().contiguous()
|
||||||
|
zeros = zeros.t().contiguous()
|
||||||
|
scale_zeros = zeros * scales
|
||||||
|
self.scales = scales.clone().half()
|
||||||
|
if linear.bias is not None:
|
||||||
|
self.bias = linear.bias.clone().half()
|
||||||
|
|
||||||
|
intweight = []
|
||||||
|
for idx in range(self.infeatures):
|
||||||
|
intweight.append(
|
||||||
|
torch.round(
|
||||||
|
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||||
|
/ self.scales[self.g_idx[idx]]
|
||||||
|
).to(torch.int)[:, None]
|
||||||
|
)
|
||||||
|
intweight = torch.cat(intweight, dim=1)
|
||||||
|
intweight = intweight.t().contiguous()
|
||||||
|
intweight = intweight.numpy().astype(np.uint32)
|
||||||
|
qweight = np.zeros(
|
||||||
|
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
row = 0
|
||||||
|
while row < qweight.shape[0]:
|
||||||
|
if self.bits in [4]:
|
||||||
|
for j in range(i, i + (32 // self.bits)):
|
||||||
|
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||||
|
i += 32 // self.bits
|
||||||
|
row += 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qweight = qweight.astype(np.int32)
|
||||||
|
self.qweight = torch.from_numpy(qweight)
|
||||||
|
|
||||||
|
zeros -= 1
|
||||||
|
zeros = zeros.numpy().astype(np.uint32)
|
||||||
|
qzeros = np.zeros(
|
||||||
|
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
col = 0
|
||||||
|
while col < qzeros.shape[1]:
|
||||||
|
if self.bits in [4]:
|
||||||
|
for j in range(i, i + (32 // self.bits)):
|
||||||
|
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||||
|
i += 32 // self.bits
|
||||||
|
col += 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qzeros = qzeros.astype(np.int32)
|
||||||
|
self.qzeros = torch.from_numpy(qzeros)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||||
|
x = x.reshape(-1, x.shape[-1])
|
||||||
|
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
|
||||||
|
out = torch.matmul(x, weight)
|
||||||
|
out = out.reshape(out_shape)
|
||||||
|
out = out + self.bias if self.bias is not None else out
|
||||||
|
return out
|
New file: backends/gaudi/server/text_generation_server/layers/gptq/quantize.py (1026 lines; diff suppressed because it is too large)
New file (56 lines):

import torch


# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
    y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """
    Compute SNR between y_pred (tensor) and y_real (tensor).

    SNR can be calculated as the following equation:

        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

    If x and y are matrices, the SNR error over the matrix is the mean value
    of the SNR error over all elements:

        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): predicted (e.g. quantized/dequantized) tensor.
        y_real (torch.Tensor): reference tensor.
        reduction (str, optional): "mean", "sum" or "none". Defaults to "mean".

    Raises:
        ValueError: if the two tensors have different shapes.
        ValueError: if the reduction method is not supported.

    Returns:
        torch.Tensor: the SNR error.
    """
    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. "
            f"({y_pred.shape} and {y_real.shape})"
        )
    reduction = str(reduction).lower()

    if y_pred.ndim == 1:
        y_pred = y_pred.unsqueeze(0)
        y_real = y_real.unsqueeze(0)

    y_pred = y_pred.flatten(start_dim=1)
    y_real = y_real.flatten(start_dim=1)

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    snr = (noise_power) / (signal_power + 1e-7)

    if reduction == "mean":
        return torch.mean(snr)
    elif reduction == "sum":
        return torch.sum(snr)
    elif reduction == "none":
        return snr
    else:
        raise ValueError("Unsupported reduction method.")
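# Usage sketch (arbitrary values): comparing a slightly perturbed tensor against its
# reference gives an SNR error close to the squared relative perturbation, here ~1e-4.
def _snr_example():
    y_real = torch.randn(4, 256)
    y_pred = y_real + 0.01 * torch.randn(4, 256)  # e.g. a dequantized approximation
    return torch_snr_error(y_pred, y_real, reduction="mean")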
New file (67 lines):

import torch
from torch import nn
from accelerate import init_empty_weights


# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = torch.nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias


class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states += residual
        residual = hidden_states

        return super().forward(hidden_states), residual


class FastRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()

        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        weight = weights.get_tensor(f"{prefix}.weight")
        return cls(weight, eps)

    def forward(self, hidden_states, residual=None):
        from vllm_hpu_extension.kernels import rms_norm

        orig_shape = hidden_states.shape
        if residual is not None:
            residual += hidden_states.view(residual.shape)
        else:
            residual = hidden_states
        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
        if len(orig_shape) == 2:
            residual = residual.unsqueeze(0)
        x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
        return x.view(orig_shape), residual.view(orig_shape)
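# For reference, a plain-PyTorch sketch of what the fused HPU rms_norm kernel is
# assumed to compute (the usual RMSNorm definition); this is an assumption, not the
# vendor implementation.
def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight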
New file (38 lines):

import torch
from torch.nn import functional as F


class FastLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


def get_linear(weight, bias):
    # Weights that are loaded through methods that are not
    # quantization-aware are still bare tensors. We may want
    # to change this in the future.
    if isinstance(weight, torch.Tensor):
        return FastLinear(weight, bias)

    return weight.get_linear(bias)
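# Minimal usage sketch of the dispatch above (sizes made up): a bare tensor becomes a
# FastLinear, while quantized weight objects such as the Fp8Weight/GPTQWeight
# dataclasses earlier in this diff supply their own get_linear.
def _get_linear_example():
    w = torch.randn(16, 32)   # (out_features, in_features)
    b = torch.randn(16)
    layer = get_linear(w, b)  # plain tensor -> FastLinear wrapping F.linear
    x = torch.randn(4, 32)
    return layer(x)           # shape (4, 16)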
New file: backends/gaudi/server/text_generation_server/layers/lora.py (279 lines)
|
|||||||
|
from typing import TYPE_CHECKING, Optional, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from torch import nn
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from text_generation_server.utils.sgmv import (
|
||||||
|
add_lora_a_bgmv,
|
||||||
|
add_lora_b_bgmv,
|
||||||
|
has_sgmv,
|
||||||
|
lora_a_sgmv_cutlass,
|
||||||
|
lora_b_sgmv_cutlass,
|
||||||
|
orient_for_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from text_generation_server.adapters import AdapterBatchData
|
||||||
|
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLinear(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.base_layer = base_layer
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.process_group = process_group
|
||||||
|
|
||||||
|
def forward_layer_type(
|
||||||
|
self,
|
||||||
|
result: torch.Tensor,
|
||||||
|
input: torch.Tensor,
|
||||||
|
adapter_data: "AdapterBatchData",
|
||||||
|
layer_type: str,
|
||||||
|
start_idx: int,
|
||||||
|
end_idx: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if adapter_data is None:
|
||||||
|
return result
|
||||||
|
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
|
||||||
|
|
||||||
|
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||||
|
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||||
|
# The 'result' tensor represents the full output, which can vary in size based on
|
||||||
|
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||||
|
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
|
||||||
|
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
|
||||||
|
# This approach ensures accurate LoRA application across various layer sizes and
|
||||||
|
# configurations, adapting to different model architectures and parallelization strategies.
|
||||||
|
#
|
||||||
|
# Example scenarios where this is necessary:
|
||||||
|
# 1. The adapter's size doesn't evenly divide across GPUs.
|
||||||
|
# 2. We're processing the last segment which might be smaller.
|
||||||
|
# 3. Different projection layers (q, k, v) have different sizes.
|
||||||
|
if end_idx - start_idx != result.shape[1]:
|
||||||
|
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||||
|
else:
|
||||||
|
proj = result
|
||||||
|
|
||||||
|
for r, rank_segments in data.rank_data.items():
|
||||||
|
lora_a_ptr = rank_segments.lora_a_ptr
|
||||||
|
lora_b_ptr = rank_segments.lora_b_ptr
|
||||||
|
|
||||||
|
if lora_a_ptr is None or lora_b_ptr is None:
|
||||||
|
raise ValueError("LoRA data is missing")
|
||||||
|
|
||||||
|
if data.use_sgmv:
|
||||||
|
# Use SGMV for prefill
|
||||||
|
v = lora_a_sgmv_cutlass(
|
||||||
|
input,
|
||||||
|
rank_segments.tmp_shrink,
|
||||||
|
lora_a_ptr,
|
||||||
|
rank_segments.segment_starts,
|
||||||
|
rank_segments.segment_ends,
|
||||||
|
self.layer_id,
|
||||||
|
r,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
v = self.collect_lora_a(v)
|
||||||
|
|
||||||
|
lora_b_sgmv_cutlass(
|
||||||
|
proj,
|
||||||
|
v,
|
||||||
|
rank_segments.tmp_expand,
|
||||||
|
lora_b_ptr,
|
||||||
|
rank_segments.segment_starts,
|
||||||
|
rank_segments.segment_ends,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use BGMV for decode
|
||||||
|
v = torch.zeros(
|
||||||
|
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||||
|
)
|
||||||
|
# TODO: error with [-1, 0], but not [0, -1]
|
||||||
|
add_lora_a_bgmv(
|
||||||
|
v,
|
||||||
|
input,
|
||||||
|
lora_a_ptr,
|
||||||
|
rank_segments.indices,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
v = self.collect_lora_a(v)
|
||||||
|
|
||||||
|
add_lora_b_bgmv(
|
||||||
|
proj,
|
||||||
|
v,
|
||||||
|
lora_b_ptr,
|
||||||
|
rank_segments.indices,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if end_idx - start_idx != result.shape[1]:
|
||||||
|
result[:, start_idx:end_idx] += proj
|
||||||
|
else:
|
||||||
|
for adapter_index in adapter_data.meta.adapter_set:
|
||||||
|
if data is not None and data.has_adapter(adapter_index):
|
||||||
|
adapter_mask = (
|
||||||
|
(adapter_data.meta.adapter_indices == adapter_index)
|
||||||
|
.to(input.dtype)
|
||||||
|
.view(-1, 1)
|
||||||
|
)
|
||||||
|
layer_result = self.forward_lora(
|
||||||
|
input, data, adapter_index, adapter_mask
|
||||||
|
)
|
||||||
|
result[:, start_idx:end_idx] += layer_result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def forward_lora(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
data: "BatchLoraWeights",
|
||||||
|
adapter_index: int,
|
||||||
|
adapter_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
|
||||||
|
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
|
||||||
|
|
||||||
|
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||||
|
|
||||||
|
a_out = input @ lora_a
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
a_out = self.collect_lora_a(a_out)
|
||||||
|
|
||||||
|
result = (a_out @ lora_b) * adapter_mask
|
||||||
|
return result
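# The dense fallback above is the textbook LoRA update; a self-contained sketch with
# made-up shapes (rank 8, hidden size 64):
def _dense_lora_example():
    batch, hidden, out_dim, rank = 4, 64, 64, 8
    x = torch.randn(batch, hidden)
    lora_a = torch.randn(hidden, rank) * 0.01   # down-projection
    lora_b = torch.randn(rank, out_dim) * 0.01  # up-projection
    mask = torch.tensor([1.0, 0.0, 1.0, 0.0]).view(-1, 1)  # rows using this adapter
    delta = (x @ lora_a @ lora_b) * mask  # added on top of the base layer output
    return delta  # shape (4, 64)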
|
||||||
|
|
||||||
|
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
raise NotImplementedError("Implemented in subclasses")
|
||||||
|
|
||||||
|
|
||||||
|
class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_layer: nn.Module,
|
||||||
|
layer_id: int,
|
||||||
|
layer_names: List[str],
|
||||||
|
sizes: List[int],
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
):
|
||||||
|
super().__init__(base_layer, layer_id, process_group)
|
||||||
|
self.layer_names = layer_names
|
||||||
|
self.sizes = sizes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
base_layer: nn.Module,
|
||||||
|
layer_id: int,
|
||||||
|
layer_names: List[str],
|
||||||
|
sizes: List[int],
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
):
|
||||||
|
return TensorParallelMultiAdapterLinear(
|
||||||
|
base_layer, layer_id, layer_names, sizes, process_group
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||||
|
) -> torch.Tensor:
|
||||||
|
result = self.base_layer(input)
|
||||||
|
|
||||||
|
# noop if no layer names are provided (e.g. for models without adapters)
|
||||||
|
if self.layer_names is None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# handle models like Bloom that have inputs of shape
|
||||||
|
# (batch_size, sequence_length, hidden_size)
|
||||||
|
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||||
|
# for the LoRA computation, then reshape back
|
||||||
|
prev_shape = result.shape
|
||||||
|
is_3d = len(input.shape) >= 3
|
||||||
|
if is_3d:
|
||||||
|
input = input.reshape(-1, input.shape[-1])
|
||||||
|
result = result.reshape(-1, result.shape[-1])
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for i, layer_name in enumerate(self.layer_names):
|
||||||
|
start_idx = offset // self.process_group.size()
|
||||||
|
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
|
||||||
|
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
|
||||||
|
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
|
||||||
|
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
|
||||||
|
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
|
||||||
|
# different projection sizes across layers and model architectures.
|
||||||
|
if self.sizes is not None:
|
||||||
|
offset += self.sizes[i]
|
||||||
|
end_idx = offset // self.process_group.size()
|
||||||
|
else:
|
||||||
|
end_idx = result.shape[1]
|
||||||
|
|
||||||
|
result = self.forward_layer_type(
|
||||||
|
result, input, adapter_data, layer_name, start_idx, end_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_3d:
|
||||||
|
result = result.reshape(prev_shape)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
|
||||||
|
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
|
||||||
|
#
|
||||||
|
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
|
||||||
|
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||||
|
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||||
|
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||||
|
gathered_tensors = [
|
||||||
|
torch.empty_like(a_out) for _ in range(self.process_group.size())
|
||||||
|
]
|
||||||
|
torch.distributed.all_gather(gathered_tensors, a_out)
|
||||||
|
return torch.cat(gathered_tensors, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorParallelAdapterRowLinear(LoraLinear):
|
||||||
|
def __init__(self, base_layer, layer_id, layer_name, process_group):
|
||||||
|
super().__init__(base_layer, layer_id, process_group)
|
||||||
|
self.layer_name = layer_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, base_layer, layer_id, layer_name, process_group):
|
||||||
|
return cls(base_layer, layer_id, layer_name, process_group)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||||
|
) -> torch.Tensor:
|
||||||
|
result = self.base_layer(input)
|
||||||
|
|
||||||
|
if self.layer_name is None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
|
||||||
|
stride = result.shape[-1] // self.process_group.size()
|
||||||
|
start_idx = self.process_group.rank() * stride
|
||||||
|
end_idx = (self.process_group.rank() + 1) * stride
|
||||||
|
|
||||||
|
self.forward_layer_type(
|
||||||
|
result, input, adapter_data, self.layer_name, start_idx, end_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
|
||||||
|
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
|
||||||
|
#
|
||||||
|
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
|
||||||
|
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||||
|
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||||
|
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||||
|
torch.distributed.all_reduce(a_out, group=self.process_group)
|
||||||
|
return a_out
|
New file: backends/gaudi/server/text_generation_server/layers/medusa.py (191 lines)
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
|
from text_generation_server.layers.linear import FastLinear
|
||||||
|
from text_generation_server.layers.tensor_parallel import (
|
||||||
|
TensorParallelHead,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(torch.nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = FastLinear.load(
|
||||||
|
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaModel(torch.nn.Module):
|
||||||
|
def __init__(self, config, medusa_config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
|
||||||
|
for i in range(get_speculate())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.heads:
|
||||||
|
return None
|
||||||
|
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
||||||
|
return speculative_logits
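# Shape sketch (hypothetical sizes): stacking one tensor of logits per Medusa head
# along dim=1 yields (batch, n_heads, vocab).
def _medusa_stack_example():
    batch, n_heads, vocab = 8, 3, 32000
    heads_out = [torch.randn(batch, vocab) for _ in range(n_heads)]
    return torch.stack(heads_out, dim=1)  # shape (8, 3, 32000)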
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaHead(torch.nn.Module):
|
||||||
|
def __init__(self, config, medusa_config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||||
|
for i in range(medusa_config["medusa_num_layers"])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
n = len(self.blocks)
|
||||||
|
self.out = FastLinear.load(
|
||||||
|
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
x = self.out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaHeadV1(nn.Module):
|
||||||
|
def __init__(self, lm_head, medusa):
|
||||||
|
super().__init__()
|
||||||
|
self.lm_head = lm_head
|
||||||
|
self.medusa = medusa
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(config, prefix: str, weights):
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors import safe_open
|
||||||
|
import json
|
||||||
|
|
||||||
|
speculator = config.speculator
|
||||||
|
|
||||||
|
path = speculator["path"]
|
||||||
|
medusa_config = str(Path(path) / "config.json")
|
||||||
|
|
||||||
|
for fname in speculator["model_paths"]:
|
||||||
|
filename = str(Path(path) / fname)
|
||||||
|
|
||||||
|
with open(medusa_config, "r") as f:
|
||||||
|
medusa_config = json.load(f)
|
||||||
|
routing = weights.routing
|
||||||
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k in routing and routing[k] != filename:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
|
medusa = MedusaModel(config, medusa_config, weights)
|
||||||
|
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
return MedusaHeadV1(lm_head, medusa)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
logits = self.lm_head(input)
|
||||||
|
# If we have too many tokens, we skip speculative logits
|
||||||
|
if input.shape[0] > 128:
|
||||||
|
return logits, None
|
||||||
|
|
||||||
|
speculative_logits = self.medusa(input)
|
||||||
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaHeadV2(nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors import safe_open
|
||||||
|
import json
|
||||||
|
|
||||||
|
speculator_path = config.speculator["path"]
|
||||||
|
|
||||||
|
medusa_config = str(Path(speculator_path) / "config.json")
|
||||||
|
filename = str(Path(speculator_path) / "medusa_lm_head.safetensors")
|
||||||
|
|
||||||
|
with open(medusa_config, "r") as f:
|
||||||
|
medusa_config = json.load(f)
|
||||||
|
routing = weights.routing
|
||||||
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k in routing and routing[k] != filename:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
|
self.n_medusa_heads = get_speculate()
|
||||||
|
|
||||||
|
assert medusa_config["medusa_num_layers"] == 1
|
||||||
|
self.linear = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
self.world_size = self.process_group.size()
|
||||||
|
self.rank = self.process_group.rank()
|
||||||
|
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
|
||||||
|
self.lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# If we have too many tokens, we skip speculative logits
|
||||||
|
if x.shape[0] > 128:
|
||||||
|
logits = self.lm_head(x)
|
||||||
|
return logits, None
|
||||||
|
|
||||||
|
size = x.shape[-1]
|
||||||
|
block_size = (size + self.world_size - 1) // self.world_size
|
||||||
|
start = self.rank * block_size
|
||||||
|
stop = (self.rank + 1) * block_size
|
||||||
|
|
||||||
|
x_block = x[:, start:stop]
|
||||||
|
|
||||||
|
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
|
||||||
|
medusa_res = self.act(self.linear(x)).reshape(
|
||||||
|
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply all residual medusa heads
|
||||||
|
output = x[:, start:stop].unsqueeze(-2) + medusa_res
|
||||||
|
|
||||||
|
# Gather medusa heads
|
||||||
|
world_output = [
|
||||||
|
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||||
|
]
|
||||||
|
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||||
|
world_output = torch.cat(world_output, dim=-1)
|
||||||
|
|
||||||
|
# Stack x and medusa residual x
|
||||||
|
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
|
||||||
|
|
||||||
|
# Compute lm head on x + medusa residual x
|
||||||
|
logits = self.lm_head(stacked_x)
|
||||||
|
|
||||||
|
# Finally, split logits from speculative logits
|
||||||
|
logits, speculative_logits = torch.split(
|
||||||
|
logits, [1, self.n_medusa_heads], dim=-2
|
||||||
|
)
|
||||||
|
# Squeeze added dimension
|
||||||
|
logits = logits.squeeze(-2)
|
||||||
|
|
||||||
|
return logits, speculative_logits
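# Shape sketch (hypothetical sizes) of the final split above: the lm head is applied to
# [x, medusa residuals] stacked on dim -2, then base logits and speculative logits are
# separated again.
def _split_speculative_example():
    batch, n_medusa_heads, vocab = 4, 3, 32000
    stacked_logits = torch.randn(batch, 1 + n_medusa_heads, vocab)
    logits, speculative_logits = torch.split(
        stacked_logits, [1, n_medusa_heads], dim=-2
    )
    return logits.squeeze(-2), speculative_logits  # (4, 32000) and (4, 3, 32000)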
|
New file: backends/gaudi/server/text_generation_server/layers/mlp.py (282 lines)
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
|
||||||
|
from text_generation_server.layers.tensor_parallel import TensorParallelHead
|
||||||
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorLayerNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
A L2 normalization implementation
|
||||||
|
...
|
||||||
|
Args
|
||||||
|
----
|
||||||
|
normalized_shape : int
|
||||||
|
Dimensionality of input data (size of final tensor axis)
|
||||||
|
elementwise_scale_weight : torch.Tensor
|
||||||
|
learned scaling term after normalization?
|
||||||
|
elementwise_shift_bias : torch.Tensor
|
||||||
|
learned bias term after normalization?
|
||||||
|
eps : float
|
||||||
|
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
eps=1e-06,
|
||||||
|
):
|
||||||
|
super(MLPSpeculatorLayerNorm, self).__init__()
|
||||||
|
self.weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
self.bias = weights.get_tensor(f"{prefix}.bias")
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xf = x
|
||||||
|
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
x = xf.type_as(x)
|
||||||
|
x = self.weight * x
|
||||||
|
x = x + self.bias
|
||||||
|
return x
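# Stripped of the learned weight and bias, the normalization above rescales each row to
# unit root-mean-square; a quick standalone check:
def _unit_rms_check():
    x = torch.randn(2, 16)
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-06)
    return normed.pow(2).mean(-1)  # ~1.0 per row before the affine weight/bias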
|
||||||
|
|
||||||
|
|
||||||
|
INV_SQRT2 = 2**-0.5
|
||||||
|
|
||||||
|
|
||||||
|
def simple_norm(x: torch.Tensor, eps=1e-06):
|
||||||
|
xf = x
|
||||||
|
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
|
||||||
|
x = xf.type_as(x)
|
||||||
|
return x * INV_SQRT2
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorModelTied(torch.nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.n_predict = get_speculate()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
|
||||||
|
self.proj0 = FastLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.proj.0",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.proj1 = FastLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.proj.1",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
|
||||||
|
self.ln = MLPSpeculatorLayerNorm(
|
||||||
|
prefix=f"{prefix}.ln.0",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
||||||
|
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
|
||||||
|
self.activation = nn.GELU()
|
||||||
|
self.vsize = config.vocab_size
|
||||||
|
self.inner_dim = config.speculator_config["inner_dim"]
|
||||||
|
self.top_k_tokens_per_head = [1] * self.n_predict
|
||||||
|
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
|
||||||
|
self.inner_dim / 2
|
||||||
|
)
|
||||||
|
self.emb.weight *= self.emb_weight
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
top_k_tokens_per_head = self.top_k_tokens_per_head
|
||||||
|
|
||||||
|
# k indicates # of candidates
|
||||||
|
# h indicates # of generated tokens
|
||||||
|
state = hidden_states
|
||||||
|
b = state.size(0)
|
||||||
|
ind = input_ids.unsqueeze(0)
|
||||||
|
all_probs = torch.empty(
|
||||||
|
b, self.n_predict, self.vsize, device=state.device
|
||||||
|
) # b k h v
|
||||||
|
assert (
|
||||||
|
len(top_k_tokens_per_head) == self.n_predict
|
||||||
|
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
|
||||||
|
for i in range(self.n_predict):
|
||||||
|
# Project and predict
|
||||||
|
z = self.emb(ind)
|
||||||
|
# z = z.mul(self.emb_weight) # b k d
|
||||||
|
if i == 0:
|
||||||
|
state = self.proj0(state) * self.state_weight + z
|
||||||
|
else:
|
||||||
|
state = self.proj1(state) * self.state_weight + z
|
||||||
|
state = self.activation(self.ln(state)) # b k d
|
||||||
|
probs = F.log_softmax(self.head(state), dim=-1) # b k v
|
||||||
|
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
||||||
|
|
||||||
|
# Update candidate set with new predictions
|
||||||
|
|
||||||
|
# Update distribution set with new logits
|
||||||
|
all_probs[:, i] = probs.exp()
|
||||||
|
|
||||||
|
# Update state, log_probs and ind for new predictions
|
||||||
|
state = state.unsqueeze(2).expand(
|
||||||
|
-1, -1, top_k_tokens_per_head[i], -1
|
||||||
|
) # b k k' d
|
||||||
|
state = state.reshape(-1, b, state.size(3)) # b kk' d
|
||||||
|
ind = preds.view(-1, b) # b kk'
|
||||||
|
|
||||||
|
speculative_logits = all_probs
|
||||||
|
return speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorModel(torch.nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.n_predict = get_speculate()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.emb = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.proj = [
|
||||||
|
FastLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.proj.{i}",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
self.head = nn.ModuleList(
|
||||||
|
[
|
||||||
|
FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.ln = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MLPSpeculatorLayerNorm(
|
||||||
|
prefix=f"{prefix}.ln.{i}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
||||||
|
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
|
||||||
|
self.activation = nn.GELU()
|
||||||
|
self.vsize = config.vocab_size
|
||||||
|
self.inner_dim = config.speculator_config["inner_dim"]
|
||||||
|
self.top_k_tokens_per_head = [1] * self.n_predict
|
||||||
|
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
|
||||||
|
self.inner_dim / 2
|
||||||
|
)
|
||||||
|
self.emb.weight *= self.emb_weight
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
top_k_tokens_per_head = self.top_k_tokens_per_head
|
||||||
|
|
||||||
|
# k indicates # of candidates
|
||||||
|
# h indicates # of generated tokens
|
||||||
|
state = hidden_states
|
||||||
|
b = state.size(0)
|
||||||
|
ind = input_ids.unsqueeze(0)
|
||||||
|
all_probs = torch.empty(
|
||||||
|
b, self.n_predict, self.vsize, device=state.device
|
||||||
|
) # b k h v
|
||||||
|
assert (
|
||||||
|
len(top_k_tokens_per_head) == self.n_predict
|
||||||
|
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
|
||||||
|
for i in range(self.n_predict):
|
||||||
|
# Project and predict
|
||||||
|
z = self.emb[i](ind)
|
||||||
|
# z = z.mul(self.emb_weight) # b k d
|
||||||
|
state = self.proj[i](state) * self.state_weight + z
|
||||||
|
state = self.activation(self.ln[i](state)) # b k d
|
||||||
|
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
|
||||||
|
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
||||||
|
|
||||||
|
# Update candidate set with new predictions
|
||||||
|
|
||||||
|
# Update distribution set with new logits
|
||||||
|
all_probs[:, i] = probs.exp()
|
||||||
|
|
||||||
|
# Update state, log_probs and ind for new predictions
|
||||||
|
state = state.unsqueeze(2).expand(
|
||||||
|
-1, -1, top_k_tokens_per_head[i], -1
|
||||||
|
) # b k k' d
|
||||||
|
state = state.reshape(-1, b, state.size(3)) # b kk' d
|
||||||
|
ind = preds.view(-1, b) # b kk'
|
||||||
|
|
||||||
|
speculative_logits = all_probs
|
||||||
|
return speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorHead(nn.Module):
|
||||||
|
def __init__(self, lm_head, mlp_speculator, scale_input: bool):
|
||||||
|
super().__init__()
|
||||||
|
self.lm_head = lm_head
|
||||||
|
self.mlp_speculator = mlp_speculator
|
||||||
|
self.scale_input = scale_input
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
logits = self.lm_head(input)
|
||||||
|
# If we have too many tokens, we skip speculative logits
|
||||||
|
if input.shape[0] > 128:
|
||||||
|
return logits, None
|
||||||
|
|
||||||
|
input_ids = logits.argmax(dim=-1)
|
||||||
|
if self.scale_input:
|
||||||
|
input = simple_norm(input)
|
||||||
|
speculative_logits = self.mlp_speculator(input, input_ids)
|
||||||
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(config, prefix: str, weights):
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
speculator_path = config.speculator["path"]
|
||||||
|
|
||||||
|
for fname in config.speculator["model_paths"]:
|
||||||
|
filename = str(Path(speculator_path) / fname)
|
||||||
|
routing = weights.routing
|
||||||
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k in routing and routing[k] != filename:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
|
tie_weights = config.speculator_config.get("tie_weights", False)
|
||||||
|
if tie_weights:
|
||||||
|
mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
|
||||||
|
else:
|
||||||
|
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
|
||||||
|
# This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
|
||||||
|
scale_input = config.speculator_config.get("scale_input", False)
|
||||||
|
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
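A standalone sketch (invented shapes, not part of the diff) of the RMS-style normalization that both MLPSpeculatorLayerNorm and simple_norm above apply before scaling:

import torch

x = torch.randn(4, 8)
eps = 1e-06
xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)   # unit RMS per row
print(xf.pow(2).mean(-1))                                     # ~1.0 for every row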
@@ -0,0 +1,250 @@
from typing import Optional, Protocol, runtime_checkable

import torch
import torch.nn as nn
from loguru import logger
from transformers.activations import ACT2FN

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    Weights,
    UnquantizedWeight,
)

from .fused_moe import fused_topk, grouped_topk

# NOTE: we are using a protocol here, because multiple inheritance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
# class inheritance is whacky.


@runtime_checkable
class MoELayer(Protocol):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ): ...

    def forward(
        self, x: torch.Tensor, *, gating_output: torch.Tensor
    ) -> torch.Tensor: ...


class DenseMoELayer(nn.Module):
    """
    Layer for MoE that applies *all* experts to each token and then weights
    their outputs based on the calculated routing. This layer is much slower
    than `SparseMoELayer` and should only be used when no fused kernels are
    available (e.g. for unsupported quantizers).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        super().__init__()

        assert scoring_func is None, "scoring func is not handled"
        assert e_score_correction_bias is None, "scoring correction bias is not handled"

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        self.gate_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{gate_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.up_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{up_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.down_proj = [
            TensorParallelRowLinear.load(
                None,
                prefix=f"{prefix}.{i}.{down_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        x: (sequence_length, model_dim)
        gating_output: (sequence_length, n_experts)
        """
        # optional reshape
        input_shape = x.shape
        x = x.view(-1, input_shape[-1])

        if self.n_expert_group is not None and self.topk_group is not None:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )
        topk_weights = topk_weights.to(x.dtype)

        weights = torch.zeros(
            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
        )

        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))

        out = torch.zeros_like(x)
        for i in range(self.n_experts):
            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
            h = self.down_proj[i](h, reduce=False)
            out += h * weights[:, i].view(-1, 1)

        return out


class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()
        if (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
            if (
                isinstance(weights.loader, HybridFP8UnquantLoader)
                and weights.loader.to_fp8
            ):
                cls = FP8SparseMoELayer
            else:
                cls = UnquantizedSparseMoELayer
        else:
            raise ValueError(
                f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
            )

        log_once(
            logger.info,
            "Using MoE layer with fused gemm",
        )

        self.moe = cls(
            n_expert_group=n_expert_group,
            n_experts=n_experts,
            prefix=prefix,
            renormalize=renormalize,
            topk=topk,
            topk_group=topk_group,
            weights=weights,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return self.moe(x, gating_output=gating_output)

    @staticmethod
    def is_supported(weights: Weights) -> bool:
        return (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader)
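An illustrative sketch of the dense-MoE combine step in DenseMoELayer.forward above (top-k routing, scatter into a dense per-expert weight matrix, weighted sum of expert outputs). The expert count, top-k, and dimensions are invented, and plain nn.Linear modules stand in for the loaded expert projections:

import torch
import torch.nn as nn

tokens, dim, n_experts, topk = 5, 16, 4, 2
x = torch.randn(tokens, dim)
gating_output = torch.randn(tokens, n_experts)
experts = [nn.Linear(dim, dim) for _ in range(n_experts)]   # stand-in experts

topk_weights = torch.softmax(gating_output, dim=-1)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize

weights = torch.zeros(tokens, n_experts)
weights.scatter_(1, topk_ids, topk_weights)      # dense (tokens, n_experts) routing matrix

out = torch.zeros_like(x)
for i in range(n_experts):
    out += experts[i](x) * weights[:, i].view(-1, 1)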
backends/gaudi/server/text_generation_server/layers/moe/fp8.py (new file, 173 lines)
@@ -0,0 +1,173 @@
from typing import Optional

import torch
import torch.nn as nn

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
    Fp8Weight,
    fp8_quantize,
    quant_dtype,
    normalize_e4m3fn_to_native_float8,
)

try:
    from .unquantized import fused_moe
except Exception:
    fused_moe = None


class FP8SparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        (
            self.gate_up_proj,
            self.gate_up_proj_weight_scale,
            self.gate_up_proj_input_scale,
        ) = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
            _load_expert_weights_row(
                prefix=prefix,
                n_experts=n_experts,
                name=down_proj_name,
                weights=weights,
            )
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            use_fp8_w8a8=True,
            w1_scale=self.gate_up_proj_weight_scale,
            w2_scale=self.down_proj_weight_scale,
            a1_scale=self.gate_up_proj_input_scale,
            a2_scale=self.down_proj_input_scale,
        )


def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i, name, weights)

        assert isinstance(weight, Fp8Weight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=quant_dtype,
                device=weight.weight.device,
            )
        if all_weight_scales is None:
            all_weight_scales = torch.empty(
                (n_experts,) + weight.weight_scale.shape,
                dtype=torch.float32,
                device=weight.weight.device,
            )

        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            all_weight[i], all_weight_scales[i], current_input_scale = (
                normalize_e4m3fn_to_native_float8(
                    weight.weight, weight.weight_scale, weight.input_scale
                )
            )
            if current_input_scale is not None:
                if max_input_scale is None or current_input_scale > max_input_scale:
                    max_input_scale = current_input_scale
        else:
            all_weight[i], all_weight_scales[i] = fp8_quantize(
                weight.weight, scalar=True
            )

    assert all_weight is not None

    return all_weight, all_weight_scales, max_input_scale


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
    )


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
    )
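A standalone sketch (invented sizes, not part of the diff) of the stacking pattern that _load_expert_weights above relies on: per-expert 2-D weights are copied into one pre-allocated (n_experts, out, in) tensor so the fused kernel can consume a single array:

import torch

n_experts, out_dim, in_dim = 4, 6, 8
per_expert = [torch.randn(out_dim, in_dim) for _ in range(n_experts)]  # stand-in expert weights

all_weight = torch.empty((n_experts, out_dim, in_dim))
for i, w in enumerate(per_expert):
    all_weight[i] = w          # expert i lives at index i of the stacked tensor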
@@ -16,10 +16,8 @@
 from typing import Tuple

 import torch
-import torch.distributed


-# TODO: Remove the functions once moe_kernel are built for ROCM
 def grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -50,3 +48,18 @@ def grouped_topk(
     topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

     return topk_weights, topk_ids
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    topk_weights = torch.nn.functional.softmax(
+        gating_output, dim=1, dtype=torch.float32
+    )
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
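A worked example (values invented) of the fused_topk routine added above: softmax over the gating scores, keep the top-k, then renormalize so the kept weights sum to 1:

import torch

gating_output = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
topk_weights = torch.nn.functional.softmax(gating_output, dim=1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, 2, dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
print(topk_ids)       # tensor([[0, 1]])
print(topk_weights)   # approximately tensor([[0.7311, 0.2689]])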
@@ -0,0 +1,121 @@
from typing import Optional

import torch
import torch.nn as nn

from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import DynamicFusedMOE


class UnquantizedSparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
        )

        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
        for i in range(n_experts):
            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return self.hpu_fused_moe(x, gating_output, self.topk)


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_weights_row(
            f"{prefix}.{i}.{name}",
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight
backends/gaudi/server/text_generation_server/layers/rotary.py (new file, 606 lines)
@@ -0,0 +1,606 @@
import os
import math
import torch
from torch import nn
from habana_frameworks.torch.hpex.kernels import (
    RotaryPosEmbeddingMode,
    apply_rotary_pos_emb,
)


def _create_inv_freq(dim, base, device):
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )
    return inv_freq


def _get_rope_config(config):
    if os.getenv("ROPE_SCALING", None) is not None:
        rope_scaling = {
            "type": os.environ["ROPE_SCALING"],
            "factor": float(os.environ["ROPE_FACTOR"]),
        }
        return rope_scaling
    return getattr(config, "rope_scaling", None)


class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
        super().__init__()
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.scaling_factor = scaling_factor
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, inv_freq.device, max_position_embeddings
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        num_tokens = query.shape[0]
        head_size = query.shape[-1]
        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # to query hidden dimension, so the original tensors need to be
        # expanded
        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # and expansion of cos/sin tensors via concatenation
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        rotary_dim = cos.shape[-1]
        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))

        key_shape = key.shape
        key = key.view(num_tokens, -1, head_size)
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))

    @classmethod
    def static(cls, config, dim, base, device):
        inv_freq = _create_inv_freq(dim, base, device)
        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if not hasattr(config, "max_position_embeddings") and hasattr(
            config, "max_seq_len"
        ):
            # handling for dbrx
            config.max_position_embeddings = config.max_seq_len
        if rope_scaling is not None:
            # `rope_type` is now standard in transformers, but some existing models
            # have `type` instead.
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))

            if rope_type == "linear":
                pass
            elif rope_type == "default":
                pass
            elif rope_type == "mrope":
                mrope_section = rope_scaling["mrope_section"]
                if mrope_section is not None:
                    return RotaryPositionEmbeddingMultimodalSections(
                        inv_freq,
                        scaling_factor,
                        mrope_section,
                        config.max_position_embeddings,
                    )
            elif rope_type == "dynamic":
                scaling_factor = rope_scaling["factor"]
                return DynamicPositionRotaryEmbedding(
                    dim=dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_type == "llama3":
                inv_freq = apply_llama3_scaling(
                    inv_freq,
                    scaling_factor=rope_scaling["factor"],
                    low_freq_factor=rope_scaling["low_freq_factor"],
                    high_freq_factor=rope_scaling["high_freq_factor"],
                    original_max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                )

                return cls(inv_freq, scaling_factor, config.max_position_embeddings)

            elif rope_type == "yarn":
                scaling_factor = rope_scaling["factor"]
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            elif rope_type in ["su", "longrope"]:
                short_factor = torch.tensor(
                    rope_scaling["short_factor"], dtype=torch.float32, device=device
                )
                short_inv_freq = 1.0 / (
                    short_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )
                long_factor = torch.tensor(
                    rope_scaling["long_factor"], dtype=torch.float32, device=device
                )
                long_inv_freq = 1.0 / (
                    long_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )

                original_max_position_embeddings = (
                    config.original_max_position_embeddings
                )
                max_position_embeddings = config.max_position_embeddings
                if max_position_embeddings <= original_max_position_embeddings:
                    scaling_factor = 1.0
                else:
                    scale = max_position_embeddings / original_max_position_embeddings
                    scaling_factor = math.sqrt(
                        1 + math.log(scale) / math.log(original_max_position_embeddings)
                    )

                # if short_mscale and long_mscale are provided we need to scale the freqs
                # using the Phi3LongRoPEScaledRotaryEmbedding
                if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
                    short_mscale = rope_scaling["short_mscale"]
                    long_mscale = rope_scaling["long_mscale"]
                    return Phi3LongRoPEScaledRotaryEmbedding(
                        short_inv_freq=short_inv_freq,
                        long_inv_freq=long_inv_freq,
                        max_position_embeddings=config.max_position_embeddings,
                        short_mscale=short_mscale,
                        long_mscale=long_mscale,
                        original_max_position_embeddings=original_max_position_embeddings,
                    )

                return SuRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    scaling_factor=scaling_factor,
                    original_max_position_embeddings=original_max_position_embeddings,
                    max_position_embeddings=config.max_position_embeddings,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor, config.max_position_embeddings)

    @classmethod
    def load(cls, config, prefix, weights):
        # XXX: Always load this in float32 !
        dtype = weights.dtype
        weights.dtype = torch.float32
        inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
        weights.dtype = dtype

        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            scaling_factor = rope_scaling["factor"]
            if rope_scaling["type"] == "linear":
                pass
            elif rope_scaling["type"] == "dynamic":
                return DynamicPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=config.max_position_embeddings,
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_scaling["type"] == "yarn":
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor, config.max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            if self.scaling_factor is not None:
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)

        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet another control flow.
        return cos.unsqueeze(1), sin.unsqueeze(1)


class SuRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq,
        long_inv_freq,
        scaling_factor,
        original_max_position_embeddings,
        max_position_embeddings,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen

            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )
            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            freqs = torch.cat([short_freqs, long_freqs])

            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)


class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq: torch.Tensor,
        long_inv_freq: torch.Tensor,
        max_position_embeddings: int,
        short_mscale: float,
        long_mscale: float,
        original_max_position_embeddings: int,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.max_position_embeddings = max_position_embeddings
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale
        self.original_max_position_embeddings = original_max_position_embeddings

        # cache
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)

            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )

            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            short_freqs = short_freqs * self.short_mscale
            long_freqs = long_freqs * self.long_mscale

            freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
            freqs[: self.original_max_position_embeddings] = short_freqs
            freqs[self.original_max_position_embeddings :] = long_freqs

            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                newbase = self.base * (
                    (self.scaling_factor * seqlen / self.max_position_embeddings)
                    - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


# Find dim range bounds based on rotations
def find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def get_mscale(scale: float = 1.0, mscale: float = 1.0):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        dim,
        max_position_embeddings,
        base,
        device,
        scaling_factor,
        *,
        extrapolation_factor,
        attn_factor,
        beta_fast,
        beta_slow,
        mscale: float,
        mscale_all_dim: float,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(
            inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
        )
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale_all_dim = mscale_all_dim
        self.scaling_factor = scaling_factor
        self.mscale = float(
            get_mscale(self.scaling_factor, mscale)
            / get_mscale(self.scaling_factor, mscale_all_dim)
            * self.attn_factor
        )  # Get n-d magnitude scaling corrected for interpolation

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings or True:
                inv_freq_extrapolation = _create_inv_freq(
                    self.dim, self.base, self.inv_freq.device
                )
                freqs = 1.0 / inv_freq_extrapolation
                inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
                low, high = find_correction_range(
                    self.beta_fast,
                    self.beta_slow,
                    self.dim,
                    self.base,
                    self.max_position_embeddings,
                )

                inv_freq_mask = (
                    1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
                inv_freq = (
                    inv_freq_interpolation * (1 - inv_freq_mask)
                    + inv_freq_extrapolation * inv_freq_mask
                )

                self.inv_freq = inv_freq

            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)


def apply_llama3_scaling(
    freqs: torch.Tensor,
    *,
    scaling_factor: int,
    low_freq_factor: int,
    high_freq_factor: int,
    original_max_position_embeddings: int,
):
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    new_freqs = []

    for freq in freqs:
        wavelen = 2 * math.pi / freq

        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scaling_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)

    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
    def __init__(
        self,
        inv_freq: torch.Tensor,
        scaling_factor: float,
        sections: list,
        max_position_embeddings,
    ):
        self.sections = sections
        self._cos_cached = None
        self._sin_cached = None
        self.section_indices = (
            torch.arange(len(self.sections))
            .repeat_interleave(torch.tensor(self.sections))
            .view(1, 1, -1)
            .to(inv_freq.device)
        )
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(
        self, dtype: torch.dtype, device: torch.device, seqlen: int
    ):
        # always cache the cos/sin for the full sequence length to avoid
        # recomputing if the sequence length is smaller than the cached one
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
            self._sections = self.section_indices.expand(seqlen, -1, -1)

    def get_cos_sin(
        self,
        position_ids: torch.Tensor,
    ):
        slen = position_ids.shape[0]

        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
        return cos, sin
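A standalone numeric sketch (tiny invented dimensions, not part of the diff) of the inverse-frequency table and cos/sin cache computed by _create_inv_freq and _update_cos_sin_cache above:

import torch

dim, base, seqlen = 4, 10000.0, 6
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)           # (seqlen, dim // 2)
cos, sin = torch.cos(freqs), torch.sin(freqs)
print(inv_freq)                            # tensor([1.0000, 0.0100])
print(cos.shape, sin.shape)                # torch.Size([6, 2]) torch.Size([6, 2])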
@@ -0,0 +1,52 @@
import torch
import json
from typing import Tuple, Optional
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.mlp import MLPSpeculatorHead


class SpeculativeHead(torch.nn.Module):
    def __init__(self, lm_head, speculator):
        super().__init__()
        self.head = lm_head
        self.speculator = speculator

    @staticmethod
    def load(config, prefix: str, weights):
        speculator = config.speculator
        if speculator:
            speculator_path = config.speculator["path"]
            speculator_config = str(speculator_path / "config.json")

            with open(speculator_config, "r") as f:
                speculator_config = json.load(f)

            config.speculator_config = speculator_config
            try:
                architecture = speculator_config["architectures"][0]

                if architecture == "MLPSpeculatorPreTrainedModel":
                    speculator = MLPSpeculatorHead.load(config, prefix, weights)
                else:
                    speculator = None
            except KeyError:
                try:
                    speculator = MedusaHeadV1.load(config, prefix, weights)
                except Exception:
                    speculator = MedusaHeadV2(config, prefix, weights)
            lm_head = None
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            speculator = None
        return SpeculativeHead(lm_head, speculator)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.speculator is not None:
            return self.speculator(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None
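A minimal sketch (made-up config dict, no TGI objects) of the dispatch logic in SpeculativeHead.load above: an MLP speculator is selected from the "architectures" entry, while Medusa-style configs, which lack that key, fall through to the Medusa heads:

speculator_config = {"architectures": ["MLPSpeculatorPreTrainedModel"]}

try:
    architecture = speculator_config["architectures"][0]
    use_mlp_speculator = architecture == "MLPSpeculatorPreTrainedModel"
except KeyError:
    use_mlp_speculator = False   # Medusa config: no "architectures" entry
print(use_mlp_speculator)        # True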
@@ -0,0 +1,244 @@
import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear

import habana_frameworks.torch as htorch


class LayerConcat(torch.nn.Module):
    """
    Apply multiple layers to the input and concatenate their
    outputs.
    """

    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
        """
        `dim` is the dimension along which layer outputs are concatenated.
        """
        super().__init__()
        self.layers = layers
        self.dim = dim

    def forward(self, x: torch.Tensor):
        outputs = [layer(x) for layer in self.layers]
        return torch.cat(outputs, self.dim)


class SuperLayer(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if config.quantize == "exl2":
            try:
                # If the piece and LM head embeddings are shared, we have
                # non-quantized weights...
                weight = weights.get_tensor(f"{prefix}.weight")
            except Exception:
                # ...otherwise they are quantized.
                weight = weights.get_weights_col(prefix)
            should_gather = weights.process_group.size() > 1
        elif weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)
            htorch.core.mark_step()
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]

        htorch.core.mark_step()
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_gate_up(prefix)
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
linear = get_linear(weight, bias)
|
||||||
|
return cls(linear)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_qkv(
|
||||||
|
cls,
|
||||||
|
config,
|
||||||
|
prefix: str,
|
||||||
|
weights,
|
||||||
|
bias: bool,
|
||||||
|
num_heads: int,
|
||||||
|
num_key_value_heads: int,
|
||||||
|
):
|
||||||
|
"""Specific method when the QKV was joined after the fact"""
|
||||||
|
weight = weights.get_weights_col_packed_qkv(
|
||||||
|
prefix,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_key_value_heads=num_key_value_heads,
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
raise NotImplementedError("packed_qkv only implemented for baichuan")
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
linear = get_linear(weight, bias)
|
||||||
|
return cls(linear)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, config, prefix: str, weights, bias: bool):
|
||||||
|
weight = weights.get_weights_col(prefix)
|
||||||
|
if bias:
|
||||||
|
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
linear = get_linear(weight, bias)
|
||||||
|
return cls(linear)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||||
|
if config.quantize == "exl2":
|
||||||
|
linears = []
|
||||||
|
for prefix in prefixes:
|
||||||
|
weight = weights.get_weights_col(prefix)
|
||||||
|
b = weights.get_tensor(f"{prefix}.bias") if bias else None
|
||||||
|
linears.append(get_linear(weight, b))
|
||||||
|
linear = LayerConcat(linears)
|
||||||
|
else:
|
||||||
|
weight = weights.get_multi_weights_col(prefixes, dim=dim)
|
||||||
|
if bias:
|
||||||
|
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||||
|
bias = torch.cat(b, dim=dim)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
linear = get_linear(weight, bias)
|
||||||
|
return cls(linear)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorParallelRowLinear(SuperLayer):
|
||||||
|
def __init__(self, linear, process_group):
|
||||||
|
super().__init__(linear)
|
||||||
|
self.process_group = process_group
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, config, prefix: str, weights, bias: bool):
|
||||||
|
weight = weights.get_weights_row(prefix)
|
||||||
|
|
||||||
|
if bias and weights.process_group.rank() == 0:
|
||||||
|
# Rank is only on the first rank process
|
||||||
|
bias = weights.get_tensor(f"{prefix}.bias")
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
return cls(
|
||||||
|
get_linear(weight, bias),
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
|
||||||
|
out = super().forward(input)
|
||||||
|
if self.process_group.size() > 1 and reduce:
|
||||||
|
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||||
|
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||||
|
# (which is required for tensor parallel HPUGraph inference)
|
||||||
|
htorch.core.mark_step()
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class TensorParallelEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, weights, reduce=True):
|
||||||
|
super().__init__()
|
||||||
|
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
|
||||||
|
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
|
||||||
|
world_size = process_group.size()
|
||||||
|
rank = process_group.rank()
|
||||||
|
|
||||||
|
block_size = (num_embeddings + world_size - 1) // world_size
|
||||||
|
self.min_id = rank * block_size
|
||||||
|
self.max_id = min(num_embeddings, (rank + 1) * block_size)
|
||||||
|
self.null_idx = weight.shape[
|
||||||
|
0
|
||||||
|
] # Usually block_size, might be less in non even vocab_size.
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
self.reduce = reduce
|
||||||
|
|
||||||
|
"""Additional 0 entry used for masking"""
|
||||||
|
self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
|
||||||
|
# translate for [0, self.max_id - self.min_id[
|
||||||
|
input = torch.where(
|
||||||
|
(self.min_id > input) | (input >= self.max_id),
|
||||||
|
self.null_idx,
|
||||||
|
input - self.min_id,
|
||||||
|
)
|
||||||
|
out = torch.nn.functional.embedding(input, self.weight)
|
||||||
|
if self.reduce and self.process_group.size() > 1:
|
||||||
|
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||||
|
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||||
|
# (which is required for tensor parallel HPUGraph inference)
|
||||||
|
htorch.core.mark_step()
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
return out
|
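Note: `TensorParallelEmbedding` above shards the vocabulary into contiguous blocks of `block_size = ceil(num_embeddings / world_size)` rows per rank and routes out-of-range token ids to an extra zero row before the all-reduce. A small self-contained sketch of that index arithmetic with made-up sizes (illustration only, mirroring the masking in `forward`):

    import torch

    num_embeddings, world_size, rank = 10, 4, 1  # assumed toy sizes
    block_size = (num_embeddings + world_size - 1) // world_size  # 3 rows per rank
    min_id = rank * block_size
    max_id = min(num_embeddings, (rank + 1) * block_size)
    null_idx = max_id - min_id  # index of the padded zero row on this shard

    input_ids = torch.tensor([0, 3, 4, 9])
    local = torch.where(
        (input_ids < min_id) | (input_ids >= max_id), null_idx, input_ids - min_id
    )
    # -> tensor([3, 0, 1, 3]): ids 3 and 4 live on this shard, the others map to
    #    the masked row and contribute zeros before the cross-rank all_reduce.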
994 backends/gaudi/server/text_generation_server/models/__init__.py Normal file
@@ -0,0 +1,994 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum

# Needed to properly setup habana_frameworks

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights

from text_generation_server.utils.log import log_master
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = False
if ATTENTION == "paged":
    FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
        FlashDeepseekV3ForCausalLM,
        DeepseekV3Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.flash_mllama import (
        FlashMllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_llava_next import (
        FlashLlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.idefics3 import (
        Idefics3ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_5_vl import (
        Qwen2_5VLForConditionalGeneration,
        Qwen2_5_VLConfig,
        Qwen2_5_VLProcessor,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)


class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    DEEPSEEK_V3 = {
        "type": "deepseek_v3",
        "name": "Deepseek V3",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    IDEFICS3 = {
        "type": "idefics3",
        "name": "Idefics 3",
        "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GRANITE = {
        "type": "granite",
        "name": "Granite",
        "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "mamba",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    QWEN2_VL = {
        "type": "qwen2_vl",
        "name": "Qwen 2 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    }
    QWEN2_5_VL = {
        "type": "qwen2_5_vl",
        "name": "Qwen 2.5 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }


__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]

SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    global FLASH_ATTENTION

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]

    kv_cache_dtype = dtype

    if FLASH_ATTENTION:
        if model_type == DEEPSEEK_V2:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif model_type == DEEPSEEK_V3:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV3ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV3Config,
                head_size=head_size,
            )

        elif (
            model_type == GPT_BIGCODE
            or model_type == GPT2
            and model_id.startswith("bigcode/")
        ):
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif model_type == GPT2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPT2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GPTJ:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTJForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GPT_NEOX:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif model_type == PHI:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == PHI_MOE:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == BAICHUAN:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GEMMA:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GEMMA2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == COHERE:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == DBRX:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif (
            model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
            and not sharded
            and not config_dict.get("alibi", False)
        ):
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashRWForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                aliases={
                    "lm_head.weight": ["transformer.word_embeddings.weight"],
                    "transformer.word_embeddings.weight": ["lm_head.weight"],
                },
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=RWConfig,
            )
        elif model_type == MISTRAL:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == MIXTRAL:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == STARCODER2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == QWEN2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == QWEN2_VL:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Qwen2VLForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == QWEN2_5_VL:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Qwen2_5VLForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=Qwen2_5_VLConfig,
                processor_class=Qwen2_5_VLProcessor,
            )
        elif model_type == MLLAMA:
            return FlashMllamaCausalLM(
                model_id=model_id,
                model_class=FlashMllamaForConditionalGeneration,
                batch_class=FlashMllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == IDEFICS2:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        elif model_type == IDEFICS3:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Idefics3ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 1456}},
            )
        elif model_type == PALIGEMMA:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        elif model_type == LLAVA_NEXT:
            return FlashVlmCausalLM(
                model_class=FlashLlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
            )

    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    adapt_transformers_to_gaudi()
    if SDP_ON_BF16 == 1:
        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
    if model_type == "gpt_bigcode":
        return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
    if model_type == "bloom":
        return BLOOM(
            model_id=model_id,
            revision=revision,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "llava_next":
        return VlmCausalLM(
            model_class=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=None,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "mllama":
        return VlmCausalLM(
            model_class=MllamaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=None,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    raise ValueError(f"Unsupported model type {model_type}")


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
                "qkv_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model
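Note: `get_model` above is essentially a dispatch table from the checkpoint's `model_type` to a model class, with the medusa/mlp_speculator branches only rewriting `model_id` and `speculator` before dispatch; `get_model_with_lora_adapters` adds the post-load adapter hook. A hedged sketch of a direct call, with placeholder argument values (the real server derives these from its CLI flags):

    # illustrative only: argument values are placeholders, not defaults from the diff
    model = get_model(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        lora_adapter_ids=[],
        revision=None,
        sharded=False,
        quantize=None,
        speculate=None,
        dtype=torch.bfloat16,
        trust_remote_code=False,
        max_input_tokens=4096,
    )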
52 backends/gaudi/server/text_generation_server/models/bloom.py Normal file
@@ -0,0 +1,52 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import torch

from typing import Optional, Type

from transformers import PreTrainedTokenizerBase

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2


class BloomCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
        batch.keys_head_dim_last = False
        return batch


class BLOOM(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(BLOOM, self).__init__(
            model_id=model_id,
            revision=revision,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch
1426 backends/gaudi/server/text_generation_server/models/causal_lm.py Normal file
File diff suppressed because it is too large
@@ -0,0 +1,923 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""

import math
import os
import warnings
from typing import Optional, Tuple, Union

import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from transformers import BloomConfig, PreTrainedModel

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    SpeculativeHead,
)

CUSTOM_KERNELS_ENABLED = False
if (
    torch.cuda.is_available()
    and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
):
    try:
        from custom_kernels import fused_bloom_attention_cuda

        CUSTOM_KERNELS_ENABLED = True
    except ImportError:
        pass

_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"

BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bigscience/bigscience-small-testing",
    "bigscience/bloom-560m",
    "bigscience/bloom-1b1",
    "bigscience/bloom-1b7",
    "bigscience/bloom-3b",
    "bigscience/bloom-7b1",
    "bigscience/bloom",
]


def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Make causal mask used for self-attention.
    """
    batch_size, target_length = input_ids_shape
    mask = torch.ones(
        (target_length, target_length + past_key_values_length),
        dtype=torch.bool,
        device=device,
    )
    mask = mask.triu(1 + past_key_values_length)

    expanded_mask = mask.unsqueeze(0).expand(
        batch_size, target_length, target_length + past_key_values_length
    )
    return expanded_mask


def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    """
    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
    """
    batch_size, src_length = mask.shape
    tgt_length = tgt_length if tgt_length is not None else src_length

    expanded_mask = ~(mask[:, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, tgt_length, src_length)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
    `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`, *required*):
            number of heads
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor

    Returns:
        Tensor shaped (batch_size * num_heads, 1, max_seq_len)
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
        device=attention_mask.device,
        dtype=torch.float32,
    )
    powers = torch.arange(
        1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
    )
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
            device=attention_mask.device,
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(
            1,
            1 + 2 * num_remaining_heads,
            2,
            device=attention_mask.device,
            dtype=torch.int32,
        )
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcasted correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi
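Note: the docstring above relies on the fact that adding a constant to every logit leaves the softmax unchanged, so ALiBi can be folded into the attention bias; the per-head slopes form a geometric series. A short sketch of the slope computation for a power-of-two head count (assumed num_heads=8, illustration only):

    import math
    import torch

    num_heads = 8  # assumed head count for illustration
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest_power_of_2))
    # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
    # each head scales the cumulative key positions by its own slope, which is
    # what build_alibi_tensor does above before broadcasting over query_length.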
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
|
def dropout_add(
|
||||||
|
x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Dropout add function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (`torch.tensor`, *required*):
|
||||||
|
input tensor
|
||||||
|
residual (`torch.tensor`, *required*):
|
||||||
|
esidual tensor
|
||||||
|
prob (`float`, *required*):
|
||||||
|
dropout probability
|
||||||
|
training (`bool`, *required*):
|
||||||
|
training mode
|
||||||
|
"""
|
||||||
|
out = F.dropout(x, p=prob, training=training)
|
||||||
|
out = residual + out
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script # this is shit for unknow reasons.
|
||||||
|
def _split_heads(
|
||||||
|
fused_qkv: torch.Tensor, num_heads: int, head_dim: int
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
|
||||||
|
storage as `fused_qkv`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
|
||||||
|
value: [batch_size, seq_length, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
|
||||||
|
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
|
||||||
|
query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
|
||||||
|
|
||||||
|
query_layer = query_layer.transpose(1, 2).reshape(
|
||||||
|
batch_size * num_heads, seq_length, head_dim
|
||||||
|
)
|
||||||
|
key_layer = key_layer.permute(0, 2, 3, 1).reshape(
|
||||||
|
batch_size * num_heads, head_dim, seq_length
|
||||||
|
)
|
||||||
|
value_layer = value_layer.transpose(1, 2).reshape(
|
||||||
|
batch_size * num_heads, seq_length, head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
return query_layer, key_layer, value_layer
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
|
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Merge heads together over the last dimenstion
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
|
||||||
|
"""
|
||||||
|
# What we want to achieve is:
|
||||||
|
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
|
||||||
|
batch_size_and_num_heads, seq_length, _ = x.shape
|
||||||
|
batch_size = batch_size_and_num_heads // num_heads
|
||||||
|
|
||||||
|
# First view to decompose the batch size
|
||||||
|
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
|
||||||
|
x = x.view(batch_size, num_heads, seq_length, head_dim)
|
||||||
|
|
||||||
|
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
|
||||||
|
x = x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
|
||||||
|
return x.reshape(batch_size, seq_length, num_heads * head_dim)
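# Illustrative shapes (descriptive comment): with batch_size=2, num_heads=4, head_dim=8 and
# seq_length=5, an input of shape [8, 5, 8] is merged back into [2, 5, 32].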
|
||||||
|
|
||||||
|
|
||||||
|
class BloomAttention(nn.Module):
|
||||||
|
def __init__(self, prefix, config: BloomConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pretraining_tp = config.pretraining_tp
|
||||||
|
self.slow_but_exact = config.slow_but_exact
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.n_head
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.split_size = self.hidden_size
|
||||||
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
|
||||||
|
if self.head_dim * self.num_heads != self.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
||||||
|
f" {self.num_heads})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Layer-wise attention scaling
|
||||||
|
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||||
|
self.beta = 1.0
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
if self.num_heads % process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // process_group.size()
|
||||||
|
self.query_key_value = TensorParallelColumnLinear.load(
|
||||||
|
config=config,
|
||||||
|
prefix=f"{prefix}.query_key_value",
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.dense = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_attention(
|
||||||
|
fused_qkv: torch.Tensor,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
alibi: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
head_mask: Optional[torch.Tensor],
|
||||||
|
beta: float,
|
||||||
|
inv_norm_factor: float,
|
||||||
|
num_heads: int,
|
||||||
|
use_cache: bool,
|
||||||
|
):
|
||||||
|
batch_size, q_length, three_times_hidden_size = fused_qkv.shape
|
||||||
|
head_dim = three_times_hidden_size // (3 * num_heads)
|
||||||
|
batch_size * num_heads
|
||||||
|
|
||||||
|
### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
|
||||||
|
# query/value: [batch_size * num_heads, seq_length, head_dim], key: [batch_size * num_heads, head_dim, seq_length]
|
||||||
|
(query_layer, key_layer, value_layer) = _split_heads(
|
||||||
|
fused_qkv, num_heads=num_heads, head_dim=head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
past_key, past_value = layer_past
|
||||||
|
# concatenate along seq_length dimension:
|
||||||
|
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||||
|
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||||
|
past_key = past_key.view(-1, *past_key.shape[-2:])
|
||||||
|
key_layer = torch.cat((past_key, key_layer), dim=2)
|
||||||
|
past_value = past_value.view(-1, *past_value.shape[-2:])
|
||||||
|
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||||
|
|
||||||
|
_, _, kv_length = key_layer.shape
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
present = (key_layer, value_layer)
|
||||||
|
else:
|
||||||
|
present = None
|
||||||
|
###
|
||||||
|
|
||||||
|
# [batch_size * num_heads, q_length, kv_length]
|
||||||
|
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
||||||
|
attention_scores = alibi.baddbmm(
|
||||||
|
batch1=query_layer,
|
||||||
|
batch2=key_layer,
|
||||||
|
beta=beta,
|
||||||
|
alpha=inv_norm_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size * num_heads, q_length, kv_length]
|
||||||
|
input_dtype = attention_scores.dtype
|
||||||
|
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||||
|
if input_dtype == torch.float16:
|
||||||
|
attention_scores = attention_scores.to(torch.float)
|
||||||
|
# torch.finfo is not supported by torch.jit; under scripting it was temporarily replaced with `-1e34`
|
||||||
|
attn_weights = attention_scores.masked_fill_(
|
||||||
|
attention_mask, torch.finfo(attention_scores.dtype).min
|
||||||
|
)
|
||||||
|
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||||
|
input_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# # [batch_size, num_heads, q_length, kv_length]
|
||||||
|
# attention_probs = self.attention_dropout(attention_probs)
|
||||||
|
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs = attention_probs * head_mask
|
||||||
|
|
||||||
|
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||||
|
context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
|
||||||
|
|
||||||
|
# merge heads back: [batch_size, q_length, num_heads * head_dim]
|
||||||
|
context_layer = _merge_heads(
|
||||||
|
context_layer, num_heads=num_heads, head_dim=head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
return context_layer, present, attention_probs
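# Output shapes for reference (descriptive comment): `context_layer` is
# [batch_size, q_length, num_heads * head_dim]; `present` (when `use_cache` is True) holds the key as
# [batch_size * num_heads, head_dim, kv_length] and the value as [batch_size * num_heads, kv_length, head_dim];
# `attention_probs` is [batch_size * num_heads, q_length, kv_length].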
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
alibi: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
|
fused_qkv = self.query_key_value(
|
||||||
|
hidden_states
|
||||||
|
) # [batch_size, seq_length, 3 x hidden_size]
|
||||||
|
batch_size, q_length, _ = fused_qkv.shape
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
past_key, past_value = layer_past
|
||||||
|
layer_past = (
|
||||||
|
past_key.view(-1, *past_key.shape[-2:]),
|
||||||
|
past_value.view(-1, *past_value.shape[-2:]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
|
||||||
|
assert self.training is False, "Only the forward pass was implemented"
|
||||||
|
assert (
|
||||||
|
attention_mask.shape[-1] < 4096
|
||||||
|
), "Custom kernel support only up to 4096 tokens"
|
||||||
|
(
|
||||||
|
context_layer,
|
||||||
|
present,
|
||||||
|
attention_probs,
|
||||||
|
) = fused_bloom_attention_cuda.forward(
|
||||||
|
fused_qkv,
|
||||||
|
layer_past,
|
||||||
|
alibi,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
self.beta,
|
||||||
|
self.inv_norm_factor,
|
||||||
|
self.num_heads,
|
||||||
|
use_cache,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
context_layer, present, attention_probs = self.compute_attention(
|
||||||
|
fused_qkv=fused_qkv,
|
||||||
|
layer_past=layer_past,
|
||||||
|
alibi=alibi,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
beta=self.beta,
|
||||||
|
inv_norm_factor=self.inv_norm_factor,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
|
||||||
|
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||||
|
slices = self.hidden_size / self.pretraining_tp
|
||||||
|
output_tensor = torch.zeros_like(context_layer)
|
||||||
|
for i in range(self.pretraining_tp):
|
||||||
|
output_tensor = output_tensor + F.linear(
|
||||||
|
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||||
|
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_tensor = self.dense(context_layer)
|
||||||
|
|
||||||
|
# output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||||
|
output_tensor += residual
|
||||||
|
|
||||||
|
outputs = (output_tensor, present)
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attention_probs,)
|
||||||
|
|
||||||
|
return outputs
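# Output structure for reference (descriptive comment): `outputs` is (output_tensor, present), with
# `attention_probs` appended when `output_attentions=True`; `output_tensor` already includes the residual.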
|
||||||
|
|
||||||
|
|
||||||
|
class BloomMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config: BloomConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pretraining_tp = config.pretraining_tp
|
||||||
|
self.slow_but_exact = config.slow_but_exact
|
||||||
|
self.dense_h_to_4h = TensorParallelColumnLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.dense_4h_to_h = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.gelu_impl = torch.nn.GELU(approximate="tanh")
|
||||||
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_states: torch.Tensor, residual: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||||
|
intermediate_output = torch.zeros_like(residual)
|
||||||
|
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
|
||||||
|
for i in range(self.pretraining_tp):
|
||||||
|
intermediate_output = intermediate_output + F.linear(
|
||||||
|
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||||
|
self.dense_4h_to_h.weight[
|
||||||
|
:, int(i * slices) : int((i + 1) * slices)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
intermediate_output = self.dense_4h_to_h(hidden_states)
|
||||||
|
|
||||||
|
# output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
|
||||||
|
intermediate_output += residual
|
||||||
|
|
||||||
|
return intermediate_output
|
||||||
|
|
||||||
|
|
||||||
|
class BloomBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id: int, config: BloomConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
prefix = f"h.{layer_id}"
|
||||||
|
self.input_layernorm = LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_epsilon,
|
||||||
|
)
|
||||||
|
self.num_heads = config.n_head
|
||||||
|
self.self_attention = BloomAttention(
|
||||||
|
prefix=f"{prefix}.self_attention", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
self.apply_residual_connection_post_layernorm = (
|
||||||
|
config.apply_residual_connection_post_layernorm
|
||||||
|
)
|
||||||
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
alibi: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
|
# hidden_states: [batch_size, seq_length, hidden_size]
|
||||||
|
|
||||||
|
# Layer norm at the beginning of the transformer layer.
|
||||||
|
layernorm_output = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Layer norm post the self attention.
|
||||||
|
if self.apply_residual_connection_post_layernorm:
|
||||||
|
residual = layernorm_output
|
||||||
|
else:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
# Self attention.
|
||||||
|
attn_outputs = self.self_attention(
|
||||||
|
layernorm_output,
|
||||||
|
residual,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
alibi=alibi,
|
||||||
|
head_mask=head_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_output = attn_outputs[0]
|
||||||
|
|
||||||
|
outputs = attn_outputs[1:]
|
||||||
|
|
||||||
|
layernorm_output = self.post_attention_layernorm(attention_output)
|
||||||
|
|
||||||
|
# Get residual
|
||||||
|
if self.apply_residual_connection_post_layernorm:
|
||||||
|
residual = layernorm_output
|
||||||
|
else:
|
||||||
|
residual = attention_output
|
||||||
|
|
||||||
|
# MLP.
|
||||||
|
output = self.mlp(layernorm_output, residual)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs = (output,) + outputs
|
||||||
|
else:
|
||||||
|
outputs = (output,) + outputs[1:]
|
||||||
|
|
||||||
|
return outputs # hidden_states, present, attentions
|
||||||
|
|
||||||
|
|
||||||
|
class BloomPreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = BloomConfig
|
||||||
|
base_model_prefix = "transformer"
|
||||||
|
_no_split_modules = ["BloomBlock"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_standard_cache(
|
||||||
|
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
|
||||||
|
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
|
||||||
|
num_heads, ...]))
|
||||||
|
"""
|
||||||
|
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||||
|
num_heads = batch_size_times_num_heads // batch_size
|
||||||
|
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
|
||||||
|
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
|
||||||
|
return tuple(
|
||||||
|
(
|
||||||
|
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
|
||||||
|
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
||||||
|
)
|
||||||
|
for layer_past in past_key_value
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_bloom_cache(
|
||||||
|
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||||
|
"""
|
||||||
|
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||||
|
batch_size_times_num_heads = batch_size * num_heads
|
||||||
|
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
|
||||||
|
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
||||||
|
return tuple(
|
||||||
|
(
|
||||||
|
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
|
||||||
|
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
|
||||||
|
)
|
||||||
|
for layer_past in past_key_value
|
||||||
|
)
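# Illustrative example (descriptive comment): for batch_size=2, num_heads=4, head_dim=8, seq_length=5,
# the standard cache stores the key as [2, 4, 8, 5] and the value as [2, 4, 5, 8], while the Bloom cache
# flattens these to [8, 8, 5] and [8, 5, 8]; the two conversion helpers above are exact inverses (both
# are pure `view` calls).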
|
||||||
|
|
||||||
|
|
||||||
|
class BloomModel(BloomPreTrainedModel):
|
||||||
|
def __init__(self, config: BloomConfig, weights):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.n_head
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
|
||||||
|
self.word_embeddings = TensorParallelEmbedding(
|
||||||
|
prefix="word_embeddings", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
self.word_embeddings_layernorm = LayerNorm.load(
|
||||||
|
prefix="word_embeddings_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transformer blocks
|
||||||
|
self.h = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BloomBlock(layer_id=layer_id, config=config, weights=weights)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Final Layer Norm
|
||||||
|
self.ln_f = LayerNorm.load(
|
||||||
|
prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_attn_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_shape: Tuple[int, int],
|
||||||
|
past_key_values_length: int,
|
||||||
|
) -> torch.BoolTensor:
|
||||||
|
# create causal mask
|
||||||
|
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
|
||||||
|
combined_attention_mask = None
|
||||||
|
device = attention_mask.device
|
||||||
|
_, src_length = input_shape
|
||||||
|
|
||||||
|
if src_length > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape,
|
||||||
|
device=device,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask
|
||||||
|
if combined_attention_mask is None
|
||||||
|
else expanded_attn_mask | combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
|
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
||||||
|
self.word_embeddings = new_embeddings
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**deprecated_arguments,
|
||||||
|
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||||
|
# `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect whether users were explicitly passing `None`
|
||||||
|
warnings.warn(
|
||||||
|
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||||
|
" passing `position_ids`.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
if len(deprecated_arguments) > 0:
|
||||||
|
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both input_ids and inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = tuple([None] * len(self.h))
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape batch_size x num_heads x N x N
|
||||||
|
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||||
|
|
||||||
|
presents = () if use_cache else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
if past_key_values[0] is not None:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[-1]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past), device=hidden_states.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
|
|
||||||
|
alibi = build_alibi_tensor(attention_mask, self.num_heads)
|
||||||
|
|
||||||
|
causal_mask = self._prepare_attn_mask(
|
||||||
|
attention_mask,
|
||||||
|
input_shape=(batch_size, seq_length),
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(self, "tp_rank"):
|
||||||
|
assert self.num_heads % self.tp_world_size == 0
|
||||||
|
block_size = self.num_heads // self.tp_world_size
|
||||||
|
alibi = alibi[
|
||||||
|
:, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
|
||||||
|
]
|
||||||
|
alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
|
||||||
|
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
|
||||||
|
else:
|
||||||
|
alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
|
||||||
|
causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
|
||||||
|
|
||||||
|
alibi = alibi.to(hidden_states.dtype)
|
||||||
|
|
||||||
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
outputs = block(
|
||||||
|
hidden_states,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
head_mask=head_mask[i],
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
alibi=alibi,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
if use_cache is True:
|
||||||
|
presents = presents + (outputs[1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (
|
||||||
|
outputs[2 if use_cache else 1],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add last hidden state
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
presents,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=presents,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BloomForCausalLM(BloomPreTrainedModel):
|
||||||
|
def __init__(self, prefix: str, config, weights):
|
||||||
|
super().__init__(config)
|
||||||
|
self.transformer = BloomModel(config, weights)
|
||||||
|
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
config,
|
||||||
|
prefix="word_embeddings",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict:
|
||||||
|
# only last token for input_ids if past is not None
|
||||||
|
if past_key_values:
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
|
||||||
|
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
|
||||||
|
past_key_values = self._convert_to_bloom_cache(past_key_values)
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
|
else:
|
||||||
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**deprecated_arguments,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
|
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
|
||||||
|
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||||
|
"""
|
||||||
|
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||||
|
# `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect whether users were explicitly passing `None`
|
||||||
|
warnings.warn(
|
||||||
|
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||||
|
" passing `position_ids`.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
if len(deprecated_arguments) > 0:
|
||||||
|
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
|
input_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + transformer_outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return (
|
||||||
|
CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
attentions=transformer_outputs.attentions,
|
||||||
|
),
|
||||||
|
speculative_logits,
|
||||||
|
)
|
@@ -0,0 +1,817 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.modeling_attn_mask_utils import (
|
||||||
|
_create_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_attention_mask,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import (
|
||||||
|
BaseModelOutputWithPooling,
|
||||||
|
)
|
||||||
|
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionEmbeddings(nn.Module):
|
||||||
|
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.image_size = config.image_size
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
|
# TODO: Should we TP (tensor-parallelize) this?
|
||||||
|
self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
in_channels=config.num_channels,
|
||||||
|
out_channels=self.embed_dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.patch_embedding.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||||
|
self.num_positions = self.num_patches + 1
|
||||||
|
self.position_embedding = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.position_embedding", weights=weights
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"position_ids",
|
||||||
|
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||||
|
batch_size = pixel_values.shape[0]
|
||||||
|
target_dtype = self.patch_embedding.weight.dtype
|
||||||
|
patch_embeds = self.patch_embedding(
|
||||||
|
pixel_values.to(dtype=target_dtype)
|
||||||
|
) # shape = [*, width, grid, grid]
|
||||||
|
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||||
|
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||||
|
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||||
|
return embeddings
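# Illustrative shapes (assuming the openai/clip-vit-base-patch32 configuration: image_size=224,
# patch_size=32, hidden_size=768): num_patches = (224 // 32) ** 2 = 49, so `embeddings` has shape
# [batch_size, 50, 768] (49 patch tokens plus 1 class token).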
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextEmbeddings(nn.Module):
|
||||||
|
def __init__(self, config: CLIPTextConfig):
|
||||||
|
super().__init__()
|
||||||
|
embed_dim = config.hidden_size
|
||||||
|
|
||||||
|
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||||
|
self.position_embedding = nn.Embedding(
|
||||||
|
config.max_position_embeddings, embed_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
|
self.register_buffer(
|
||||||
|
"position_ids",
|
||||||
|
torch.arange(config.max_position_embeddings).expand((1, -1)),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
seq_length = (
|
||||||
|
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = self.position_ids[:, :seq_length]
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.token_embedding(input_ids)
|
||||||
|
|
||||||
|
position_embeddings = self.position_embedding(position_ids)
|
||||||
|
embeddings = inputs_embeds + position_embeddings
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_size = self.embed_dim // self.num_heads
|
||||||
|
if self.head_size * self.num_heads != self.embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||||
|
f" {self.num_heads})."
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||||
|
self.scale = self.head_size**-0.5
|
||||||
|
self.dropout = config.attention_dropout
|
||||||
|
|
||||||
|
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.out_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.out_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
|
return (
|
||||||
|
tensor.view(bsz, seq_len, self.num_heads, self.head_size)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
# get query proj
|
||||||
|
|
||||||
|
qkv = self.qkv(hidden_states)
|
||||||
|
query_states, key_states, value_states = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
]
|
||||||
|
* 3,
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
query_states = query_states * self.scale
|
||||||
|
key_states = self._shape(key_states, -1, bsz)
|
||||||
|
value_states = self._shape(value_states, -1, bsz)
|
||||||
|
|
||||||
|
proj_shape = (bsz * self.num_heads, -1, self.head_size)
|
||||||
|
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||||
|
key_states = key_states.view(*proj_shape)
|
||||||
|
value_states = value_states.view(*proj_shape)
|
||||||
|
|
||||||
|
src_len = key_states.size(1)
|
||||||
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply the causal_attention_mask first
|
||||||
|
if causal_attention_mask is not None:
|
||||||
|
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
|
f" {causal_attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = (
|
||||||
|
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
+ causal_attention_mask
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = (
|
||||||
|
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
+ attention_mask
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
attn_probs = nn.functional.dropout(
|
||||||
|
attn_weights, p=self.dropout, training=self.training
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None
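# Note (descriptive comment): the second element of the returned tuple is always `None`; attention
# weights are not exposed by this implementation. After the row-parallel output projection,
# `attn_output` has shape [batch_size, tgt_len, config.hidden_size].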
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.activation_fn = ACT2FN[config.hidden_act]
|
||||||
|
self.fc1 = TensorParallelColumnLinear.load(
|
||||||
|
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.fc2 = TensorParallelRowLinear.load(
|
||||||
|
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.self_attn = CLIPAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.layer_norm1 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
self.layer_norm2 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
causal_attention_mask: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
|
||||||
|
"""
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states, attn_weights = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
causal_attention_mask=causal_attention_mask,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPPreTrainedModel(nn.Module):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = CLIPConfig
|
||||||
|
base_model_prefix = "clip"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
|
||||||
|
CLIP_START_DOCSTRING = r"""
|
||||||
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||||
|
etc.)
|
||||||
|
|
||||||
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||||
|
and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||||||
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it.
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
|
config.max_position_embeddings - 1]`.
|
||||||
|
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
|
"""
|
||||||
|
|
||||||
|
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
|
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||||
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CLIP_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it.
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
|
config.max_position_embeddings - 1]`.
|
||||||
|
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
|
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||||
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||||
|
return_loss (`bool`, *optional*):
|
||||||
|
Whether or not to return the contrastive loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||||
|
[`CLIPEncoderLayer`].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: CLIPConfig
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
CLIPEncoderLayer(
|
||||||
|
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||||
|
)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
|
than the model's internal embedding lookup matrix.
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
"""
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
|
hidden_states = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
causal_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextTransformer(nn.Module):
|
||||||
|
def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
embed_dim = config.hidden_size
|
||||||
|
self.embeddings = CLIPTextEmbeddings(config)
|
||||||
|
# Initialize weights and apply final processing with `self.post_init()`
|
||||||
|
self.encoder = CLIPEncoder(
|
||||||
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
# For `pooled_output` computation
|
||||||
|
self.eos_token_id = config.eos_token_id
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
if input_ids is None:
|
||||||
|
raise ValueError("You have to specify input_ids")
|
||||||
|
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
||||||
|
|
||||||
|
# CLIP's text model uses causal mask, prepare it here.
|
||||||
|
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||||
|
causal_attention_mask = _create_4d_causal_attention_mask(
|
||||||
|
input_shape, hidden_states.dtype, device=hidden_states.device
|
||||||
|
)
|
||||||
|
# expand attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
attention_mask = _prepare_4d_attention_mask(
|
||||||
|
attention_mask, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
inputs_embeds=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
causal_attention_mask=causal_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden_state = encoder_outputs[0]
|
||||||
|
last_hidden_state = self.final_layer_norm(last_hidden_state)
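# Note (descriptive comment): the indexing expressions below mirror the upstream `pooled_output`
# selection of the EOS-token hidden state, but their result is not assigned or returned here;
# only `last_hidden_state` is returned.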
|
||||||
|
|
||||||
|
if self.eos_token_id == 2:
|
||||||
|
# The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
|
||||||
|
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
|
last_hidden_state[
|
||||||
|
torch.arange(
|
||||||
|
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||||
|
),
|
||||||
|
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
|
||||||
|
dim=-1
|
||||||
|
),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# The config has the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
|
||||||
|
last_hidden_state[
|
||||||
|
torch.arange(
|
||||||
|
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||||
|
),
|
||||||
|
# We need to get the first position of the `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
|
||||||
|
(
|
||||||
|
input_ids.to(dtype=torch.int, device=last_hidden_state.device)
|
||||||
|
== self.eos_token_id
|
||||||
|
)
|
||||||
|
.int()
|
||||||
|
.argmax(dim=-1),
|
||||||
|
]
|
||||||
|
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextModel(CLIPPreTrainedModel):
|
||||||
|
config_class = CLIPTextConfig
|
||||||
|
|
||||||
|
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
||||||
|
|
||||||
|
def __init__(self, prefix, config: CLIPTextConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.text_model = CLIPTextTransformer(prefix, config)
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, CLIPTextModel
|
||||||
|
|
||||||
|
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
|
||||||
|
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> last_hidden_state = outputs.last_hidden_state
|
||||||
|
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
||||||
|
```"""
|
||||||
|
|
||||||
|
return self.text_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionTransformer(nn.Module):
|
||||||
|
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embeddings = CLIPVisionEmbeddings(
|
||||||
|
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.pre_layrnorm = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
self.encoder = CLIPEncoder(
|
||||||
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
|
)
|
||||||
|
# self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
if pixel_values is None:
|
||||||
|
raise ValueError("You have to specify pixel_values")
|
||||||
|
|
||||||
|
hidden_states = self.embeddings(pixel_values)
|
||||||
|
hidden_states = self.pre_layrnorm(hidden_states)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
inputs_embeds=hidden_states,
|
||||||
|
)
|
||||||
|
last_hidden_state = encoder_outputs
|
||||||
|
# pooled_output = last_hidden_state[:, 0, :]
|
||||||
|
# pooled_output = self.post_layernorm(pooled_output)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPooling(
|
||||||
|
last_hidden_state=last_hidden_state,
|
||||||
|
# pooler_output=pooled_output,
|
||||||
|
# hidden_states=encoder_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionModel(CLIPPreTrainedModel):
|
||||||
|
config_class = CLIPVisionConfig
|
||||||
|
main_input_name = "pixel_values"
|
||||||
|
_no_split_modules = ["CLIPEncoderLayer"]
|
||||||
|
|
||||||
|
def __init__(self, config: CLIPVisionConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.vision_model = CLIPVisionTransformer(config)
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
|
return self.vision_model.embeddings.patch_embedding
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, CLIPVisionModel
|
||||||
|
|
||||||
|
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> last_hidden_state = outputs.last_hidden_state
|
||||||
|
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
||||||
|
```"""
|
||||||
|
|
||||||
|
return self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPModel(nn.Module):
|
||||||
|
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
text_config = config.text_config
|
||||||
|
vision_config = config.vision_config
|
||||||
|
|
||||||
|
self.projection_dim = config.projection_dim
|
||||||
|
self.text_embed_dim = text_config.hidden_size
|
||||||
|
self.vision_embed_dim = vision_config.hidden_size
|
||||||
|
|
||||||
|
self.text_model = CLIPTextTransformer(text_config)
|
||||||
|
self.vision_model = CLIPVisionTransformer(vision_config)
|
||||||
|
|
||||||
|
self.visual_projection = nn.Linear(
|
||||||
|
self.vision_embed_dim, self.projection_dim, bias=False
|
||||||
|
)
|
||||||
|
self.text_projection = nn.Linear(
|
||||||
|
self.text_embed_dim, self.projection_dim, bias=False
|
||||||
|
)
|
||||||
|
self.logit_scale = nn.Parameter(
|
||||||
|
torch.tensor(config.logit_scale_init_value)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_text_features(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
|
||||||
|
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, CLIPModel
|
||||||
|
|
||||||
|
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
|
||||||
|
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||||
|
>>> text_features = model.get_text_features(**inputs)
|
||||||
|
```"""
|
||||||
|
text_outputs = self.text_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooled_output = text_outputs[1]
|
||||||
|
text_features = self.text_projection(pooled_output)
|
||||||
|
|
||||||
|
return text_features
|
||||||
|
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
|
||||||
|
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, CLIPModel
|
||||||
|
|
||||||
|
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> image_features = model.get_image_features(**inputs)
|
||||||
|
```"""
|
||||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooled_output = vision_outputs[1] # pooled_output
|
||||||
|
image_features = self.visual_projection(pooled_output)
|
||||||
|
|
||||||
|
return image_features
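
# --- Illustrative sketch (not part of the original file) ---
# get_image_features above is just "pooled vision output -> visual projection".
# The Linear layer here is a stand-in for self.visual_projection; sizes are
# made up for the example.
import torch

batch, vision_hidden, projection_dim = 2, 32, 16
pooled_output = torch.randn(batch, vision_hidden)  # plays the role of vision_outputs[1]
visual_projection = torch.nn.Linear(vision_hidden, projection_dim, bias=False)
image_features = visual_projection(pooled_output)
assert image_features.shape == (batch, projection_dim)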
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, CLIPModel
|
||||||
|
|
||||||
|
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(
|
||||||
|
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||||
|
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||||
|
```"""
|
||||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
text_outputs = self.text_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_embeds = vision_outputs[1]
|
||||||
|
image_embeds = self.visual_projection(image_embeds)
|
||||||
|
|
||||||
|
text_embeds = text_outputs[1]
|
||||||
|
text_embeds = self.text_projection(text_embeds)
|
||||||
|
|
||||||
|
# normalized features
|
||||||
|
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||||
|
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# cosine similarity as logits
|
||||||
|
logit_scale = self.logit_scale.exp()
|
||||||
|
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||||
|
logits_per_image = logits_per_text.t()
|
||||||
|
|
||||||
|
return logits_per_image, logits_per_text
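
# --- Illustrative sketch (not part of the original file) ---
# A self-contained toy version of the similarity computation above, with
# random tensors standing in for the text/vision towers, to show how the
# normalized embeddings, logit scale and transposed logits relate. The
# logit_scale value is just an example.
import torch

batch_texts, batch_images, projection_dim = 3, 2, 8
text_embeds = torch.randn(batch_texts, projection_dim)
image_embeds = torch.randn(batch_images, projection_dim)

# L2-normalize so the matmul below is a cosine similarity.
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

# The learned logit_scale is stored as a log; exp() recovers the temperature.
logit_scale = torch.tensor(2.6592).exp()

logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()

# Each row of logits_per_image softmaxes over the candidate texts.
probs = logits_per_image.softmax(dim=1)
assert probs.shape == (batch_images, batch_texts)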
|
@@ -0,0 +1,493 @@
|
# coding=utf-8
|
||||||
|
# Copyright 2024 Cohere team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
SpeculativeHead,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.layernorm import (
|
||||||
|
FastLayerNorm,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.rotary import (
|
||||||
|
PositionRotaryEmbedding,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
from habana_frameworks.torch.hpex.kernels import (
|
||||||
|
RotaryPosEmbeddingMode,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CohereRotary(PositionRotaryEmbedding):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
):
|
||||||
|
# Such control flow may add some overhead.
|
||||||
|
num_tokens = query.shape[0]
|
||||||
|
head_size = query.shape[-1]
|
||||||
|
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
||||||
|
sin = torch.repeat_interleave(sin, 2, dim=-1)
|
||||||
|
cos = torch.repeat_interleave(cos, 2, dim=-1)
|
||||||
|
rotary_dim = cos.shape[-1]
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, head_size)
|
||||||
|
query_rot = query[..., :rotary_dim]
|
||||||
|
query_pass = query[..., rotary_dim:]
|
||||||
|
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||||
|
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||||
|
|
||||||
|
key_shape = key.shape
|
||||||
|
key = key.view(num_tokens, -1, head_size)
|
||||||
|
key_rot = key[..., :rotary_dim]
|
||||||
|
key_pass = key[..., rotary_dim:]
|
||||||
|
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||||
|
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
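
# --- Illustrative sketch (not part of the original file) ---
# A plain-PyTorch reference for the pairwise (interleaved) rotary embedding
# applied by the HPU kernel above. cos/sin are taken per rotary pair (i.e.
# before the repeat_interleave), and the full head dim is rotated here for
# simplicity; all sizes are made up.
import torch


def rotate_pairwise(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [tokens, heads, dim]; cos/sin: [tokens, dim // 2]
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    cos = cos.unsqueeze(1)  # broadcast over heads
    sin = sin.unsqueeze(1)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out


tokens, heads, dim = 4, 2, 8
query = torch.randn(tokens, heads, dim)
positions = torch.arange(tokens, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
angles = torch.outer(positions, inv_freq)  # [tokens, dim // 2]
rotated = rotate_pairwise(query, angles.cos(), angles.sin())
assert rotated.shape == query.shape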
|
||||||
|
|
||||||
|
|
||||||
|
class CohereLayerNorm(nn.Module):
    def __init__(self, prefix, weights, eps):
        super().__init__()
        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
        self.weight = nn.Parameter(weight)
        # Fake weights
        self.ones = weight.new_ones(weight.shape[1])
        self.eps = eps

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(
            -1, self.weight.shape[0], self.weight.shape[1]
        )
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        hidden_states_minus_mean = hidden_states - mean
        variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        hidden_states = hidden_states.view(-1, self.weight.shape[1])
        return hidden_states.to(input_dtype)
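
# --- Illustrative sketch (not part of the original file) ---
# CohereLayerNorm above is a bias-free LayerNorm with one scale vector per
# head, applied to tensors reshaped to [tokens, num_heads, head_size]. The
# check below reproduces its math and compares it to torch's layer_norm; the
# shapes are made up.
import torch
import torch.nn.functional as F

num_heads, head_size, eps = 4, 16, 1e-5
weight = torch.randn(num_heads, head_size)
x = torch.randn(8, num_heads * head_size)

h = x.reshape(-1, num_heads, head_size).to(torch.float32)
mean = h.mean(-1, keepdim=True)
var = (h - mean).pow(2).mean(-1, keepdim=True)
out = weight * (h - mean) * torch.rsqrt(var + eps)

ref = F.layer_norm(h, (head_size,), weight=None, bias=None, eps=eps) * weight
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)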
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix, weights):
|
||||||
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
|
return _load_gqa(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
return TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gqa(config, prefix: str, weights):
|
||||||
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||||
|
|
||||||
|
weight = weights.get_multi_weights_col(
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(weight, UnquantizedWeight):
|
||||||
|
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||||
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||||
|
assert list(weight.weight.shape) == [
|
||||||
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||||
|
config.hidden_size,
|
||||||
|
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||||
|
|
||||||
|
if config.attention_bias:
|
||||||
|
w = [
|
||||||
|
weights.get_sharded(f"{p}.bias", dim=0)
|
||||||
|
for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||||
|
]
|
||||||
|
bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
|
||||||
|
return TensorParallelColumnLinear(get_linear(weight, bias=bias))
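
# --- Illustrative sketch (not part of the original file) ---
# The per-shard row count of the fused [q_proj; k_proj; v_proj] weight that
# the assert above checks, written out as plain integer arithmetic with
# example sizes.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
world_size = 4

head_size = hidden_size // num_attention_heads          # 128
heads_per_shard = num_attention_heads // world_size     # 8
kv_heads_per_shard = num_key_value_heads // world_size  # 2

qkv_rows = (heads_per_shard + 2 * kv_heads_per_shard) * head_size
assert qkv_rows == 1536  # rows held by each of the 4 shards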
|
||||||
|
|
||||||
|
|
||||||
|
class FlashCohereAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
self.rotary_emb = CohereRotary.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
|
self.use_qk_norm = config.use_qk_norm
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = CohereLayerNorm(
|
||||||
|
prefix=f"{prefix}.q_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
self.k_norm = CohereLayerNorm(
|
||||||
|
prefix=f"{prefix}.k_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.q_norm = None
|
||||||
|
self.k_norm = None
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
query, key, value = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
self.head_size * self.num_key_value_heads,
|
||||||
|
self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
query = query.reshape(-1, self.head_size)
|
||||||
|
key = key.reshape(-1, self.head_size)
|
||||||
|
query = self.q_norm(query.contiguous())
|
||||||
|
key = self.k_norm(key.contiguous())
|
||||||
|
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
key = key.view(-1, self.num_key_value_heads, self.head_size)
|
||||||
|
value = value.view(-1, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, key, cos, sin)
|
||||||
|
|
||||||
|
kv_cache.store(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# sdpa
|
||||||
|
attn_output = attention(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
query,
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
seqlen,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), reduce=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CohereMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
act = config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Fuse gate and up proj
|
||||||
|
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.intermediate_size = (
|
||||||
|
config.intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
|
||||||
|
)
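
# --- Illustrative sketch (not part of the original file) ---
# What the fused gate/up projection in CohereMLP.forward does, with plain
# tensors: one matmul produces both halves, which are split and combined as
# act(gate) * up before the down projection. SiLU stands in for the
# configured activation, and all sizes are made up.
import torch

hidden_size, intermediate_size, tokens = 64, 256, 5
gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)
down_weight = torch.randn(hidden_size, intermediate_size)
x = torch.randn(tokens, hidden_size)

gate_up = x @ gate_up_weight.t()                  # [tokens, 2 * intermediate_size]
gate_up = gate_up.view(-1, 2, intermediate_size)  # split the gate / up halves
hidden = torch.nn.functional.silu(gate_up[:, 0]) * gate_up[:, 1]
out = hidden @ down_weight.t()
assert out.shape == (tokens, hidden_size)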
|
||||||
|
|
||||||
|
|
||||||
|
class FlashCohereLayer(nn.Module):
|
||||||
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = f"{prefix}.layers.{layer_id}"
|
||||||
|
self.self_attn = FlashCohereAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
|
self.input_layernorm = FastLayerNorm.load_no_bias(
|
||||||
|
prefix=f"{prefix}.input_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_output = self.mlp(normed_hidden_states)
|
||||||
|
output = attn_output + mlp_output
|
||||||
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(output, group=self.process_group)
|
||||||
|
|
||||||
|
return output, res
|
||||||
|
|
||||||
|
|
||||||
|
class FlashCohereModel(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
FlashCohereLayer(
|
||||||
|
prefix,
|
||||||
|
layer_id,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = FastLayerNorm.load_no_bias(
|
||||||
|
prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: torch.Tensor,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid indexing in each layer
|
||||||
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashCohereForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.model = FlashCohereModel(prefix, config, weights)
|
||||||
|
try:
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
config,
|
||||||
|
prefix="lm_head",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.embed_tokens",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.logit_scale = config.logit_scale
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
logits *= self.logit_scale
|
||||||
|
if speculative_logits is not None:
|
||||||
|
speculative_logits *= self.logit_scale
|
||||||
|
return logits, speculative_logits
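
# --- Illustrative sketch (not part of the original file) ---
# How lm_head_indices and logit_scale interact in the forward above: only the
# selected positions (e.g. the last token of each packed sequence) go through
# the head, and the logits are multiplied by the model-wide scale. The Linear
# layer and all values are stand-ins.
import torch

tokens, hidden_size, vocab_size = 10, 32, 100
logit_scale = 0.0625  # example value; the real one comes from config.logit_scale
hidden_states = torch.randn(tokens, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

lm_head_indices = torch.tensor([3, 6, 9])
logits = lm_head(hidden_states[lm_head_indices]) * logit_scale
assert logits.shape == (3, vocab_size)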
|
@@ -0,0 +1,745 @@
|
# coding=utf-8
|
||||||
|
# Copyright 2022 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from typing import Optional, List, Tuple, Any
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
|
|
||||||
|
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
FastLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
SpeculativeHead,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.rotary import (
|
||||||
|
PositionRotaryEmbedding,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.layernorm import (
|
||||||
|
FastLayerNorm,
|
||||||
|
)
|
||||||
|
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||||
|
|
||||||
|
|
||||||
|
class DbrxAttentionConfig(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
attn_pdrop: float = 0,
|
||||||
|
clip_qkv: Optional[float] = None,
|
||||||
|
kv_n_heads: int = 1,
|
||||||
|
rope_theta: float = 10000.0,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.attn_pdrop = attn_pdrop
|
||||||
|
self.clip_qkv = clip_qkv
|
||||||
|
self.kv_n_heads = kv_n_heads
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
for k in ["model_type"]:
|
||||||
|
if k in kwargs:
|
||||||
|
kwargs.pop(k)
|
||||||
|
if len(kwargs) != 0:
|
||||||
|
raise ValueError(f"Found unknown {kwargs=}")
|
||||||
|
|
||||||
|
|
||||||
|
class DbrxFFNConfig(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ffn_act_fn: Optional[dict] = None,
|
||||||
|
ffn_hidden_size: int = 3584,
|
||||||
|
moe_num_experts: int = 4,
|
||||||
|
moe_top_k: int = 1,
|
||||||
|
moe_jitter_eps: Optional[float] = None,
|
||||||
|
moe_loss_weight: float = 0.01,
|
||||||
|
moe_normalize_expert_weights: Optional[float] = 1,
|
||||||
|
uniform_expert_assignment: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if ffn_act_fn is None:
|
||||||
|
ffn_act_fn = {"name": "silu"}
|
||||||
|
self.ffn_act_fn = ffn_act_fn
|
||||||
|
self.ffn_hidden_size = ffn_hidden_size
|
||||||
|
self.moe_num_experts = moe_num_experts
|
||||||
|
self.moe_top_k = moe_top_k
|
||||||
|
self.moe_jitter_eps = moe_jitter_eps
|
||||||
|
self.moe_loss_weight = moe_loss_weight
|
||||||
|
self.moe_normalize_expert_weights = moe_normalize_expert_weights
|
||||||
|
self.uniform_expert_assignment = uniform_expert_assignment
|
||||||
|
|
||||||
|
if uniform_expert_assignment:
|
||||||
|
raise ValueError("`uniform_expert_assignment = True` is not supported")
|
||||||
|
|
||||||
|
for k in ["model_type"]:
|
||||||
|
if k in kwargs:
|
||||||
|
kwargs.pop(k)
|
||||||
|
if len(kwargs) != 0:
|
||||||
|
raise ValueError(f"Found unknown {kwargs=}")
|
||||||
|
|
||||||
|
|
||||||
|
class DbrxConfig(PretrainedConfig):
|
||||||
|
attribute_map = {
|
||||||
|
"hidden_size": "d_model",
|
||||||
|
"num_attention_heads": "n_heads",
|
||||||
|
"num_hidden_layers": "n_layers",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int = 2048,
|
||||||
|
n_heads: int = 16,
|
||||||
|
n_layers: int = 24,
|
||||||
|
max_seq_len: int = 2048,
|
||||||
|
vocab_size: int = 32000,
|
||||||
|
resid_pdrop: float = 0.0,
|
||||||
|
emb_pdrop: float = 0.0,
|
||||||
|
attn_config: Optional[DbrxAttentionConfig] = None,
|
||||||
|
ffn_config: Optional[DbrxFFNConfig] = None,
|
||||||
|
use_cache: bool = True,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
output_router_logits: bool = False,
|
||||||
|
router_aux_loss_coef: float = 0.05,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
if attn_config is None:
|
||||||
|
self.attn_config = DbrxAttentionConfig()
|
||||||
|
elif isinstance(attn_config, dict):
|
||||||
|
self.attn_config = DbrxAttentionConfig(**attn_config)
|
||||||
|
else:
|
||||||
|
self.attn_config = attn_config
|
||||||
|
|
||||||
|
if ffn_config is None:
|
||||||
|
self.ffn_config = DbrxFFNConfig()
|
||||||
|
elif isinstance(ffn_config, dict):
|
||||||
|
self.ffn_config = DbrxFFNConfig(**ffn_config)
|
||||||
|
else:
|
||||||
|
self.ffn_config = ffn_config
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.resid_pdrop = resid_pdrop
|
||||||
|
self.emb_pdrop = emb_pdrop
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.output_router_logits = output_router_logits
|
||||||
|
self.router_aux_loss_coef = router_aux_loss_coef
|
||||||
|
|
||||||
|
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
|
||||||
|
if tie_word_embeddings:
|
||||||
|
raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_key_value_heads(self):
|
||||||
|
# We can't use the attribute map, since the number of KV
|
||||||
|
# heads is not top-level.
|
||||||
|
return self.attn_config.kv_n_heads
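
# --- Illustrative sketch (not part of the original file) ---
# How the nested attn_config / ffn_config dicts and the attribute_map are
# meant to be consumed, using the DbrxConfig class defined above with made-up
# values.
cfg = DbrxConfig(
    d_model=1024,
    n_heads=16,
    n_layers=2,
    attn_config={"kv_n_heads": 4, "rope_theta": 500000.0},
    ffn_config={"moe_num_experts": 8, "moe_top_k": 2},
)

# attribute_map exposes the HF-style names on top of the Dbrx-native ones.
assert cfg.hidden_size == cfg.d_model == 1024
assert cfg.num_attention_heads == cfg.n_heads == 16
# The KV head count is nested inside attn_config, hence the property above.
assert cfg.num_key_value_heads == 4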
|
||||||
|
|
||||||
|
|
||||||
|
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x.view(1) if len(x.size()) == 0 else x
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix, weights):
|
||||||
|
return TensorParallelColumnLinear.load_qkv(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.Wqkv",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
num_heads=config.n_heads,
|
||||||
|
num_key_value_heads=config.attn_config.kv_n_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_experts(config, prefix, weights):
|
||||||
|
world_size = weights.process_group.size()
|
||||||
|
rank = weights.process_group.rank()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
config.ffn_config.ffn_hidden_size % world_size == 0
|
||||||
|
), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
|
||||||
|
|
||||||
|
expert_size = config.ffn_config.ffn_hidden_size
|
||||||
|
block_size = expert_size // world_size
|
||||||
|
start = rank * block_size
|
||||||
|
stop = (rank + 1) * block_size
|
||||||
|
|
||||||
|
tensor = torch.empty(
|
||||||
|
(config.ffn_config.moe_num_experts * block_size, config.d_model),
|
||||||
|
dtype=weights.dtype,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
slice_ = weights._get_slice(f"{prefix}")
|
||||||
|
|
||||||
|
for i in range(config.ffn_config.moe_num_experts):
|
||||||
|
offset = i * expert_size
|
||||||
|
expert_slice = slice_[start + offset : stop + offset]
|
||||||
|
|
||||||
|
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
|
||||||
|
dtype=weights.dtype
|
||||||
|
).to(device=weights.device)
|
||||||
|
return tensor
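
# --- Illustrative sketch (not part of the original file) ---
# The slicing done by _load_experts, reproduced on a plain tensor: experts are
# stored stacked along dim 0, and each rank keeps a contiguous block of ffn
# rows from every expert. Example sizes only.
import torch

moe_num_experts, ffn_hidden_size, d_model = 4, 8, 6
world_size, rank = 2, 1

full = torch.arange(moe_num_experts * ffn_hidden_size * d_model, dtype=torch.float32)
full = full.view(moe_num_experts * ffn_hidden_size, d_model)

block_size = ffn_hidden_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

shard = torch.empty(moe_num_experts * block_size, d_model)
for i in range(moe_num_experts):
    offset = i * ffn_hidden_size
    shard[i * block_size : (i + 1) * block_size] = full[start + offset : stop + offset]

assert shard.shape == (moe_num_experts * block_size, d_model)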
|
||||||
|
|
||||||
|
|
||||||
|
def _load_experts_quantized(config, prefix, weights, cls):
|
||||||
|
world_size = weights.process_group.size()
|
||||||
|
rank = weights.process_group.rank()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
config.ffn_config.ffn_hidden_size % world_size == 0
|
||||||
|
), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
|
||||||
|
|
||||||
|
expert_size = config.ffn_config.ffn_hidden_size
|
||||||
|
block_size = expert_size // world_size
|
||||||
|
start = rank * block_size
|
||||||
|
stop = (rank + 1) * block_size
|
||||||
|
|
||||||
|
slice_ = weights._get_slice(f"{prefix}")
|
||||||
|
|
||||||
|
experts = []
|
||||||
|
for i in range(config.ffn_config.moe_num_experts):
|
||||||
|
if config.quantize in ["gptq", "awq"]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Dbrx does not support gptq/awq quantization yet."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
offset = i * expert_size
|
||||||
|
expert_slice = (
|
||||||
|
slice_[start + offset : stop + offset]
|
||||||
|
.to(dtype=weights.dtype)
|
||||||
|
.to(device=weights.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if cls == TensorParallelRowLinear:
|
||||||
|
expert_slice = expert_slice.t().contiguous()
|
||||||
|
linear = get_linear(expert_slice, None)
|
||||||
|
experts.append(cls(linear, weights.process_group))
|
||||||
|
else:
|
||||||
|
linear = get_linear(expert_slice, None)
|
||||||
|
experts.append(cls(linear))
|
||||||
|
|
||||||
|
return experts
|
||||||
|
|
||||||
|
|
||||||
|
class DbrxAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.clip_qkv = config.attn_config.clip_qkv
|
||||||
|
self.num_heads = config.n_heads
|
||||||
|
self.hidden_size = config.d_model
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.attn_config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.attn_config.kv_n_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.out_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
if self.clip_qkv is not None:
|
||||||
|
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
||||||
|
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# sdpa
|
||||||
|
attn_output = attention(
|
||||||
|
query=query,
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
query,
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
seqlen,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
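
# --- Illustrative sketch (not part of the original file) ---
# The clip_qkv clamp and the query / fused-kv split used in DbrxAttention,
# on plain tensors with example sizes (8 query heads, 2 KV heads per shard).
import torch

tokens, head_size, num_heads, num_kv_heads = 3, 16, 8, 2
clip_qkv = 8.0
qkv = torch.randn(tokens, (num_heads + 2 * num_kv_heads) * head_size) * 10

qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_kv_heads], dim=1
)
query = query.view(-1, num_heads, head_size)
kv = kv.view(-1, 2, num_kv_heads, head_size)  # kv[:, 0] is K, kv[:, 1] is V

assert query.shape == (tokens, num_heads, head_size)
assert kv[:, 0].shape == (tokens, num_kv_heads, head_size)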
|
||||||
|
|
||||||
|
|
||||||
|
class DbrxNormAttentionNorm(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_1 = FastLayerNorm.load_no_bias(
|
||||||
|
prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
|
||||||
|
)
|
||||||
|
self.self_attn = DbrxAttention(
|
||||||
|
prefix=f"{prefix}.attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.norm_2 = FastLayerNorm.load_no_bias(
|
||||||
|
prefix=f"{prefix}.norm_2",
|
||||||
|
weights=weights,
|
||||||
|
eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.norm_1(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# post attention layer norm
|
||||||
|
normed_attn_res_output, attn_res = self.norm_2(attn_output, res)
|
||||||
|
|
||||||
|
return normed_attn_res_output, attn_res
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def select_experts(
|
||||||
|
gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int
|
||||||
|
):
|
||||||
|
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||||
|
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||||
|
# weights, selected_experts: (sequence_length, top-k)
|
||||||
|
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
|
||||||
|
if moe_normalize_expert_weights:
|
||||||
|
weights = weights / torch.norm(
|
||||||
|
weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True
|
||||||
|
)
|
||||||
|
weights = weights.view(-1)
|
||||||
|
selected_experts = selected_experts.view(-1)
|
||||||
|
|
||||||
|
return selected_experts, weights
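
# --- Illustrative sketch (not part of the original file) ---
# A tiny demonstration of select_experts above: 3 tokens routed over 4 experts
# with top-2 routing and p=1 weight normalization.
import torch

gate_logits = torch.randn(3, 4)
selected_experts, weights = select_experts(gate_logits, 2, 1)

# Both outputs are flattened to one entry per (token, top-k slot) pair.
assert selected_experts.shape == weights.shape == (3 * 2,)
# With p=1 normalization, each token's two routing weights sum to 1.
assert torch.allclose(weights.view(3, 2).sum(dim=-1), torch.ones(3))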
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def round_up(x: torch.Tensor, value: int):
|
||||||
|
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSparseMoE(nn.Module):
|
||||||
|
def __init__(self, prefix, config: DbrxConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.moe_normalize_expert_weights = (
|
||||||
|
config.ffn_config.moe_normalize_expert_weights
|
||||||
|
)
|
||||||
|
self.hidden_dim = config.d_model
|
||||||
|
self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
|
||||||
|
self.num_experts = config.ffn_config.moe_num_experts
|
||||||
|
self.top_k = config.ffn_config.moe_top_k
|
||||||
|
|
||||||
|
act = config.ffn_config.ffn_act_fn["name"]
|
||||||
|
if "gelu" in act:
|
||||||
|
self.act = lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif "silu" in act:
|
||||||
|
self.act = torch.nn.functional.silu
|
||||||
|
else:
|
||||||
|
self.act = ACT2FN[act]
|
||||||
|
|
||||||
|
# gating
|
||||||
|
self.gate = FastLinear.load(
|
||||||
|
config, f"{prefix}.router.layer", weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||||
|
w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
|
||||||
|
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||||
|
)
|
||||||
|
v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
|
||||||
|
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||||
|
)
|
||||||
|
self.wv1 = torch.cat([w1, v1], dim=1)
|
||||||
|
self.w2 = (
|
||||||
|
_load_experts(config, f"{prefix}.experts.mlp.w2", weights)
|
||||||
|
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
|
||||||
|
for i in range(self.num_experts):
|
||||||
|
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
|
||||||
|
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits = self.gate(x)
|
||||||
|
out = self.hpu_fused_moe(x, router_logits, self.top_k)
|
||||||
|
|
||||||
|
# Reduce sum
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
|
||||||
|
return out.view(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class DenseMoE(nn.Module):
|
||||||
|
def __init__(self, prefix, config: DbrxConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.moe_normalize_expert_weights = (
|
||||||
|
config.ffn_config.moe_normalize_expert_weights
|
||||||
|
)
|
||||||
|
self.hidden_dim = config.d_model
|
||||||
|
self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
|
||||||
|
self.num_experts = config.ffn_config.moe_num_experts
|
||||||
|
self.top_k = config.ffn_config.moe_top_k
|
||||||
|
|
||||||
|
act = config.ffn_config.ffn_act_fn["name"]
|
||||||
|
if "gelu" in act:
|
||||||
|
self.act = lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif "silu" in act:
|
||||||
|
self.act = torch.nn.functional.silu
|
||||||
|
else:
|
||||||
|
self.act = ACT2FN[act]
|
||||||
|
|
||||||
|
# gating
|
||||||
|
self.gate = FastLinear.load(
|
||||||
|
config, f"{prefix}.router.layer", weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.w1 = _load_experts_quantized(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.experts.mlp.w1",
|
||||||
|
weights=weights,
|
||||||
|
cls=TensorParallelColumnLinear,
|
||||||
|
)
|
||||||
|
self.w2 = _load_experts_quantized(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.experts.mlp.w2",
|
||||||
|
weights=weights,
|
||||||
|
cls=TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
self.v1 = _load_experts_quantized(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.experts.mlp.v1",
|
||||||
|
weights=weights,
|
||||||
|
cls=TensorParallelColumnLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: (sequence_length, model_dim)
|
||||||
|
gate_logits: (sequence_length, n_experts)
|
||||||
|
"""
|
||||||
|
# optional reshape
|
||||||
|
input_shape = x.shape
|
||||||
|
x = x.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
# gate_logits: (sequence_length, n_experts)
|
||||||
|
gate_logits = self.gate(x)
|
||||||
|
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||||
|
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||||
|
|
||||||
|
if self.top_k < self.num_experts:
|
||||||
|
_, not_selected_experts = torch.topk(
|
||||||
|
weights,
|
||||||
|
self.num_experts - self.top_k,
|
||||||
|
largest=False,
|
||||||
|
sorted=False,
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
# Mask not selected experts
|
||||||
|
weights.scatter_(1, not_selected_experts, 0)
|
||||||
|
|
||||||
|
# Re-normalize
|
||||||
|
if self.moe_normalize_expert_weights:
|
||||||
|
weights = weights / torch.norm(
|
||||||
|
weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
|
||||||
|
)
|
||||||
|
weights = weights.to(x.dtype)
|
||||||
|
|
||||||
|
# Final output tensor
|
||||||
|
out = x.new_zeros(x.shape[0], self.hidden_dim)
|
||||||
|
for i in range(self.num_experts):
|
||||||
|
h = self.act(self.w1[i](x)) * self.v1[i](x)
|
||||||
|
h = self.w2[i](h, reduce=False)
|
||||||
|
# Add expert output to out with masking
|
||||||
|
out += h * weights[:, i].view(-1, 1)
|
||||||
|
|
||||||
|
# Reduce sum
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
|
||||||
|
return out
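
# --- Illustrative sketch (not part of the original file) ---
# The masking trick used in DenseMoE.forward: instead of gathering the top-k
# experts, the (num_experts - top_k) smallest routing weights are located and
# zeroed, so every expert can still be run densely. Example sizes only.
import torch

tokens, num_experts, top_k = 4, 6, 2
weights = torch.nn.functional.softmax(torch.randn(tokens, num_experts), dim=1)

_, not_selected = torch.topk(
    weights, num_experts - top_k, largest=False, sorted=False, dim=1
)
weights = weights.scatter(1, not_selected, 0.0)

# Exactly top_k non-zero routing weights remain for every token.
assert torch.all((weights > 0).sum(dim=1) == top_k)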
|
||||||
|
|
||||||
|
|
||||||
|
class DbrxLayer(nn.Module):
|
||||||
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = f"{prefix}.blocks.{layer_id}"
|
||||||
|
|
||||||
|
self.attn = DbrxNormAttentionNorm(
|
||||||
|
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
|
||||||
|
self.moe = moe_cls(f"{prefix}.ffn", config, weights)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
# Self Attention
|
||||||
|
attn_output, attn_res = self.attn(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
moe_output = self.moe(attn_output)
|
||||||
|
|
||||||
|
return moe_output, attn_res
|
||||||
|
|
||||||
|
|
||||||
|
class DbrxModel(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.wte", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DbrxLayer(
|
||||||
|
prefix,
|
||||||
|
layer_id,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = FastLayerNorm.load_no_bias(
|
||||||
|
prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].attn.self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].attn.self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid indexing in each layer
|
||||||
|
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashDbrxForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not prefix:
|
||||||
|
prefix = "transformer"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.transformer"
|
||||||
|
|
||||||
|
self.model = DbrxModel(prefix, config, weights)
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
config,
|
||||||
|
prefix="lm_head",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
return logits, speculative_logits
|
@@ -0,0 +1,633 @@
|
# coding=utf-8
|
||||||
|
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
FastLinear,
|
||||||
|
SpeculativeHead,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
|
||||||
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV2Config(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=102400,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
moe_intermediate_size=1407,
|
||||||
|
num_hidden_layers=30,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=32,
|
||||||
|
        n_shared_experts=2,
        n_routed_experts=160,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",
        n_group=8,
        topk_group=3,
        num_experts_per_tok=6,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Deepseek V2 models."
            )

        if ep_size != 1:
            raise ValueError(
                f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

class DeepseekV2Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights: Weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.value_head_size = config.v_head_dim
        self.head_pad_size = max(self.head_size, self.value_head_size)

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        mscale = get_mscale(
            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
        )
        self.softmax_scale = self.head_size**-0.5 * mscale * mscale

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        if self.q_lora_rank is None:
            self.q_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_proj",
                weights=weights,
                bias=config.attention_bias,
            )
        else:
            self.q_a_proj = get_linear(
                weight=weights.get_weights(f"{prefix}.q_a_proj"),
                bias=(
                    weights.get_tensor(f"{prefix}.q_a_proj.bias")
                    if config.attention_bias
                    else None
                ),
            )
            self.q_a_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.q_a_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.q_b_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_b_proj",
                weights=weights,
                bias=config.attention_bias,
            )

        self.kv_a_proj_with_mqa = get_linear(
            weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
            bias=(
                weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
                if config.attention_bias
                else None
            ),
        )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.kv_a_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
        )

        self.kv_b_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.kv_b_proj",
            weights=weights,
            bias=config.attention_bias,
        )

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache: KVCache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)
        else:
            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
        query = query.view(-1, self.num_heads, self.head_size)

        _, query_pe = torch.split(
            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, key_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
        )

        key_nope, value = torch.split(
            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
        )

        batch_size, heads, head_dim = query_pe.shape
        query_pe = (
            query_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        batch_size, heads, head_dim = key_pe.shape
        key_pe = (
            key_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        self.rotary_emb(query_pe, key_pe, cos, sin)

        query[..., self.qk_nope_head_dim :] = query_pe
        key = torch.empty_like(query)
        key[..., : self.qk_nope_head_dim] = key_nope
        key[..., self.qk_nope_head_dim :] = key_pe

        # We need to pad the heads because Flash Attention does not support
        # qk and v with different head sizes.
        query = torch.nn.functional.pad(
            query, (0, self.head_pad_size - self.head_size), value=0
        )
        key = torch.nn.functional.pad(
            key, (0, self.head_pad_size - self.head_size), value=0
        )
        value = torch.nn.functional.pad(
            value, (0, self.head_pad_size - self.value_head_size), value=0
        )

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        # Remove padding.
        attn_output = attn_output[..., : self.value_head_size]

        return self.o_proj(
            attn_output.reshape(-1, self.num_heads * self.value_head_size)
        )

class DeepseekV2MLP(nn.Module):
    def __init__(self, prefix: str, config, weights, intermediate_size: int):
        super().__init__()
        self.hidden_act = config.hidden_act
        if self.hidden_act != "silu":
            # Bail out because MoE only supports silu.
            raise NotImplementedError(
                "Currently only `silu` is supported as an activation for Deepseek V2."
            )
        self.act = ACT2FN[self.hidden_act]

        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )

        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

        self.intermediate_size = intermediate_size // weights.process_group.size()

        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
        )

class DeepseekV2MoE(nn.Module):
    def __init__(
        self,
        prefix,
        config: DeepseekV2Config,
        moe_layer_cls: Type[MoELayer],
        weights,
    ):
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.moe_intermediate_size = (
            config.moe_intermediate_size // weights.process_group.size()
        )
        self.routed_scaling_factor = config.routed_scaling_factor

        # Gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.moe_layer = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.n_routed_experts,
            n_expert_group=config.n_group,
            renormalize=config.norm_topk_prob,
            topk=config.num_experts_per_tok,
            topk_group=config.topk_group,
            weights=weights,
        )
        assert isinstance(self.moe_layer, MoELayer)

        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV2MLP(
                prefix=f"{prefix}.shared_experts",
                config=config,
                weights=weights,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
            )
        else:
            self.shared_experts = None

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.shared_experts is not None:
            shared_output = self.shared_experts(x, reduce=False)
        else:
            shared_output = None

        router_logits = self.gate(x)

        out = self.moe_layer(x, gating_output=router_logits)

        if shared_output is not None:
            out = out + shared_output

        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out.view(*x.shape)

class DeepseekV2Layer(nn.Module):
    def __init__(self, prefix, layer_id, config, weights):
        super().__init__()
        prefix = f"{prefix}.layers.{layer_id}"

        self.self_attn = DeepseekV2Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
        )

        if (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        ):
            moe_layer_cls = (
                SparseMoELayer
                if SparseMoELayer.is_supported(weights)
                else DenseMoELayer
            )
            self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
        else:
            self.mlp = DeepseekV2MLP(
                prefix=f"{prefix}.mlp",
                config=config,
                weights=weights,
                intermediate_size=config.intermediate_size,
            )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        normed_attn_res_output, residual = self.post_attention_layernorm(
            attn_output, residual
        )

        output = self.mlp(normed_attn_res_output)

        return output, residual

class DeepseekV2Model(torch.nn.Module):
    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )

        self.layers = nn.ModuleList(
            [
                DeepseekV2Layer(
                    prefix,
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

class FlashDeepseekV2ForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        self.model = DeepseekV2Model(
            "model" if not prefix else f"{prefix}.model", config, weights
        )
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
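The head padding and unpadding in `DeepseekV2Attention.forward` is easy to miss. Below is a minimal, self-contained sketch of the idea, using plain `torch.nn.functional.scaled_dot_product_attention` as a stand-in for TGI's `attention`/`paged_attention` kernels (an assumption, not the actual kernel path): fused flash-style kernels generally expect q, k and v to share one head size, so the 128-dim value heads are zero-padded up to the 192-dim q/k heads and the padding is sliced off the output afterwards. The tensor shapes are illustrative only.

# sketch_mla_head_padding.py -- illustrative only, not TGI code
import torch
import torch.nn.functional as F

qk_head_dim, v_head_dim = 192, 128     # qk_nope_head_dim + qk_rope_head_dim vs. v_head_dim
q = torch.randn(1, 16, 8, qk_head_dim)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 16, 8, qk_head_dim)
v = torch.randn(1, 16, 8, v_head_dim)

# zero-pad only the last (head) dimension of the values up to the q/k head size
v_padded = F.pad(v, (0, qk_head_dim - v_head_dim), value=0)
out = F.scaled_dot_product_attention(q, k, v_padded)  # default 1/sqrt(d) scaling
out = out[..., :v_head_dim]                            # "Remove padding." step
assert out.shape[-1] == v_head_dim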
@@ -0,0 +1,642 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Type

import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

from text_generation_server.layers import (
    FastLinear,
    SpeculativeHead,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    get_linear,
)
from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights

class DeepseekV3Config(PretrainedConfig):
    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=2,
        n_routed_experts=160,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",
        n_group=8,
        topk_group=3,
        num_experts_per_tok=6,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Deepseek V2 models."
            )

        if ep_size != 1:
            raise ValueError(
                f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

class DeepseekV3Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights: Weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.value_head_size = config.v_head_dim
        self.head_pad_size = max(self.head_size, self.value_head_size)

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        mscale = get_mscale(
            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
        )
        self.softmax_scale = self.head_size**-0.5 * mscale * mscale

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        if self.q_lora_rank is None:
            self.q_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_proj",
                weights=weights,
                bias=config.attention_bias,
            )
        else:
            self.q_a_proj = get_linear(
                weight=weights.get_weights(f"{prefix}.q_a_proj"),
                bias=(
                    weights.get_tensor(f"{prefix}.q_a_proj.bias")
                    if config.attention_bias
                    else None
                ),
            )
            self.q_a_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.q_a_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.q_b_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_b_proj",
                weights=weights,
                bias=config.attention_bias,
            )

        self.kv_a_proj_with_mqa = get_linear(
            weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
            bias=(
                weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
                if config.attention_bias
                else None
            ),
        )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.kv_a_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
        )

        self.kv_b_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.kv_b_proj",
            weights=weights,
            bias=config.attention_bias,
        )

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache: KVCache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)
        else:
            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
        query = query.view(-1, self.num_heads, self.head_size)

        _, query_pe = torch.split(
            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, key_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
        )

        key_nope, value = torch.split(
            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
        )

        batch_size, heads, head_dim = query_pe.shape
        query_pe = (
            query_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        batch_size, heads, head_dim = key_pe.shape
        key_pe = (
            key_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        self.rotary_emb(query_pe, key_pe, cos, sin)

        query[..., self.qk_nope_head_dim :] = query_pe
        key = torch.empty_like(query)
        key[..., : self.qk_nope_head_dim] = key_nope
        key[..., self.qk_nope_head_dim :] = key_pe

        # We need to pad the heads because Flash Attention does not support
        # qk and v with different head sizes.
        query = torch.nn.functional.pad(
            query, (0, self.head_pad_size - self.head_size), value=0
        )
        key = torch.nn.functional.pad(
            key, (0, self.head_pad_size - self.head_size), value=0
        )
        value = torch.nn.functional.pad(
            value, (0, self.head_pad_size - self.value_head_size), value=0
        )

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        # Remove padding.
        attn_output = attn_output[..., : self.value_head_size]

        return self.o_proj(
            attn_output.reshape(-1, self.num_heads * self.value_head_size)
        )

class DeepseekV3MLP(nn.Module):
    def __init__(self, prefix: str, config, weights, intermediate_size: int):
        super().__init__()
        self.hidden_act = config.hidden_act
        if self.hidden_act != "silu":
            # Bail out because MoE only supports silu.
            raise NotImplementedError(
                "Currently only `silu` is supported as an activation for Deepseek V2."
            )
        self.act = ACT2FN[self.hidden_act]

        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )

        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

        self.intermediate_size = intermediate_size // weights.process_group.size()

        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
        )

class DeepseekV3MoE(nn.Module):
    def __init__(
        self,
        prefix,
        config: DeepseekV3Config,
        moe_layer_cls: Type[MoELayer],
        weights,
    ):
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.moe_intermediate_size = (
            config.moe_intermediate_size // weights.process_group.size()
        )
        self.routed_scaling_factor = config.routed_scaling_factor

        # Gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        if config.topk_method == "noaux_tc":
            self.gate.e_score_correction_bias = torch.zeros(
                config.n_routed_experts, device=weights.device
            )
        else:
            self.gate.e_score_correction_bias = None

        self.moe_layer = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.n_routed_experts,
            n_expert_group=config.n_group,
            renormalize=config.norm_topk_prob,
            topk=config.num_experts_per_tok,
            topk_group=config.topk_group,
            weights=weights,
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
        )
        assert isinstance(self.moe_layer, MoELayer)

        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV3MLP(
                prefix=f"{prefix}.shared_experts",
                config=config,
                weights=weights,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
            )
        else:
            self.shared_experts = None

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.shared_experts is not None:
            shared_output = self.shared_experts(x, reduce=False)
        else:
            shared_output = None

        router_logits = self.gate(x)

        out = self.moe_layer(x, gating_output=router_logits)

        if shared_output is not None:
            out = out + shared_output

        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out.view(*x.shape)

class DeepseekV3Layer(nn.Module):
    def __init__(self, prefix, layer_id, config, weights):
        super().__init__()
        prefix = f"{prefix}.layers.{layer_id}"

        self.self_attn = DeepseekV3Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
        )

        if (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        ):
            moe_layer_cls = (
                SparseMoELayer
                if SparseMoELayer.is_supported(weights)
                else DenseMoELayer
            )
            self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
        else:
            self.mlp = DeepseekV3MLP(
                prefix=f"{prefix}.mlp",
                config=config,
                weights=weights,
                intermediate_size=config.intermediate_size,
            )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        normed_attn_res_output, residual = self.post_attention_layernorm(
            attn_output, residual
        )

        output = self.mlp(normed_attn_res_output)

        return output, residual

class DeepseekV3Model(torch.nn.Module):
    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )

        self.layers = nn.ModuleList(
            [
                DeepseekV3Layer(
                    prefix,
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

class FlashDeepseekV3ForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        self.model = DeepseekV3Model(
            "model" if not prefix else f"{prefix}.model", config, weights
        )
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
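The main V3-specific piece above is the `noaux_tc` branch: a per-expert `e_score_correction_bias` and a `scoring_func` are handed to the MoE layer. The selection itself happens inside TGI's `SparseMoELayer`/`DenseMoELayer` kernels; the following is only a rough, self-contained sketch of how bias-corrected, group-limited top-k routing of this kind can be expressed in plain PyTorch. The function name, the sigmoid scoring, and all sizes are illustrative assumptions, not TGI's implementation.

# sketch_noaux_tc_routing.py -- illustrative only, not TGI code
import torch

def noaux_tc_topk(router_logits, e_score_correction_bias, n_group, topk_group, top_k):
    scores = router_logits.sigmoid()            # assumed scoring_func="sigmoid"
    biased = scores + e_score_correction_bias   # the bias only influences *selection*
    n_tokens, n_experts = scores.shape
    # score each expert group by its two best biased experts, keep the best groups
    group_scores = biased.view(n_tokens, n_group, -1).topk(2, dim=-1).values.sum(-1)
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_scores.topk(topk_group, dim=-1).indices, 1.0)
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, n_group, n_experts // n_group)
        .reshape(n_tokens, n_experts)
    )
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_idx = masked.topk(top_k, dim=-1).indices
    topk_weight = scores.gather(1, topk_idx)    # weights come from the unbiased scores
    return topk_idx, topk_weight

logits = torch.randn(4, 256)                    # 4 tokens, 256 routed experts (made-up sizes)
bias = torch.zeros(256)                         # mirrors the zero-initialized correction bias
idx, w = noaux_tc_topk(logits, bias, n_group=8, topk_group=4, top_k=8)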
@@ -0,0 +1,555 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight

class Gemma2Config(PretrainedConfig):
    def __init__(
        self,
        vocab_size=256128,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

class Gemma2FastRMSNorm(FastRMSNorm):
    @classmethod
    def load(cls, prefix: str, weights, eps=1e-6):
        dtype = weights.dtype
        weights.dtype = torch.float32
        weight = weights.get_tensor(f"{prefix}.weight") + 1
        weights.dtype = dtype
        new = cls(weight, eps)
        new.dtype = dtype
        return new

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states += residual
        residual = hidden_states
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * self.weight
        return hidden_states.to(self.dtype), residual
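# Note (editor, added): `load()` above adds 1 to the stored tensor because Gemma
# checkpoints keep the RMSNorm scale as an offset from 1, i.e. the layer effectively
# computes x_normed * (1 + weight); the forward pass also accumulates in float32
# and only casts back to the checkpoint dtype at the very end.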
def load_attention(config, prefix: str, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.head_dim
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(get_linear(weight, bias=None))

class FlashGemma2Attention(torch.nn.Module):
    def __init__(
        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal
        if is_sliding:
            self.window_size = config.sliding_window
        else:
            self.window_size = -1

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        # self.softmax_scale = self.head_size**-0.5
        self.softmax_scale = config.query_pre_attn_scalar**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        self.softcap = config.attn_logit_softcapping

        query_key_value = load_attention(config, prefix, weights)
        self.query_key_value = TensorParallelMultiAdapterLinear.load(
            query_key_value,
            layer_id,
            ["q_proj", "k_proj", "v_proj"],
            sizes=[
                self.head_size * config.num_attention_heads,
                self.head_size * config.num_key_value_heads,
                self.head_size * config.num_key_value_heads,
            ],
            process_group=weights.process_group,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            layer_id,
            "o_proj",
            process_group=weights.process_group,
        )

        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.window_size,
                softcap=self.softcap,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                softcap=self.softcap,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )

class Gemma2MLP(nn.Module):
    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        act = config.hidden_activation
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            gate_up_proj,
            layer_id,
            ["gate_proj", "up_proj"],
            sizes=[
                config.intermediate_size,
                config.intermediate_size,
            ],
            process_group=weights.process_group,
        )

        down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.down_proj = TensorParallelAdapterRowLinear.load(
            down_proj,
            layer_id,
            "down_proj",
            process_group=weights.process_group,
        )

        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states, adapter_data):
        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
        )

class FlashGemma2Layer(nn.Module):
    def __init__(
        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
    ):
        super().__init__()
        self.self_attn = FlashGemma2Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            layer_id=layer_id,
            causal=causal,
            is_sliding=is_sliding,
        )
        self.mlp = Gemma2MLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
        )

        self.input_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.pre_feedforward_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.post_feedforward_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
        normed_attn_res_output = normed_attn_res_output + res
        res = normed_attn_res_output

        pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
        mlp_output = self.mlp(pre_normed, adapter_data)
        post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)

        return post_hidden_states, normed_attn_res_output

class FlashGemma2Model(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.layers = nn.ModuleList(
            [
                FlashGemma2Layer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    layer_id=layer_id,
                    causal=causal,
                    is_sliding=layer_id % 2 == 0,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        adapter_data: Optional[torch.Tensor],
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                adapter_data,
                hpu_attention_meta,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
embed_norm = config.hidden_size**0.5
|
||||||
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
|
)
|
||||||
|
self.embed_tokens.weight *= embed_norm
|
||||||
|
|
||||||
|
self.model = FlashGemma2Model(
|
||||||
|
prefix=prefix, config=config, weights=weights, causal=causal
|
||||||
|
)
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
prefix=(
|
||||||
|
f"{prefix}.embed_tokens"
|
||||||
|
if config.tie_word_embeddings
|
||||||
|
else f"{prefix}.lm_head"
|
||||||
|
),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.softcap = config.final_logit_softcapping
|
||||||
|
assert isinstance(self.softcap, float)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
input_embeds = self.embed_tokens(input_ids)
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_embeds,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
logits /= self.softcap
|
||||||
|
logits = torch.tanh(logits)
|
||||||
|
logits *= self.softcap
|
||||||
|
|
||||||
|
return logits, speculative_logits
|
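The final-logit soft-capping above reduces to a single tanh squash. A minimal, self-contained sketch for reference (the 30.0 cap is an illustrative value, not read from any config here):

import torch

def softcap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Mirror of the divide / tanh / multiply sequence in FlashGemma2ForCausalLM.forward:
    # logits stay monotonic but are bounded in (-cap, cap).
    return torch.tanh(logits / cap) * cap

print(softcap(torch.tensor([-100.0, 0.0, 100.0])))  # extreme values pulled inside the cap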
@@ -0,0 +1,469 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight


class GemmaConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=256128,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class GemmaFastRMSNorm(FastRMSNorm):
    @classmethod
    def load(cls, prefix: str, weights, eps=1e-6):
        dtype = weights.dtype
        weights.dtype = torch.float32
        weight = weights.get_tensor(f"{prefix}.weight") + 1
        weights.dtype = dtype
        new = cls(weight, eps)
        new.dtype = dtype
        return new

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states += residual
        residual = hidden_states
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * self.weight
        return hidden_states.to(self.dtype), residual


def load_attention(config, prefix: str, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.head_dim
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(get_linear(weight, bias=None))


class FlashGemmaAttention(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                causal=self.causal,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class GemmaMLP(nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


class FlashGemmaLayer(nn.Module):
    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        self.self_attn = FlashGemmaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
        )
        self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class FlashGemmaModel(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.layers = nn.ModuleList(
            [
                FlashGemmaLayer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    causal=causal,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        adapter_data: Optional[torch.Tensor],
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashGemmaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
        super().__init__()

        embed_norm = config.hidden_size**0.5
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        self.embed_tokens.weight *= embed_norm

        self.model = FlashGemmaModel(
            prefix=prefix, config=config, weights=weights, causal=causal
        )
        self.lm_head = SpeculativeHead.load(
            prefix=(
                f"{prefix}.embed_tokens"
                if config.tie_word_embeddings
                else f"{prefix}.lm_head"
            ),
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
            input_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
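As a reference for the `+ 1` in GemmaFastRMSNorm.load above: Gemma checkpoints store `weight - 1`, and the normalization runs in float32 before downcasting. A small plain-PyTorch sketch of that math (shapes are illustrative, not the serving path):

import torch

def gemma_rms_norm(x: torch.Tensor, stored_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    weight = stored_weight.to(torch.float32) + 1  # undo the checkpoint's (weight - 1) storage
    dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps) * weight).to(dtype)

print(gemma_rms_norm(torch.randn(2, 8, dtype=torch.bfloat16), torch.zeros(8)).shape)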
@@ -0,0 +1,451 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales


def load_qkv(config, prefix: str, weights, head_size, num_heads):
    if config.quantize == "gptq":
        return _load_qkv_gptq(
            config,
            prefix,
            weights,
        )
    else:
        return _load_qkv(config, prefix, weights, head_size, num_heads)


def _load_qkv_gptq(config, prefix: str, weights):
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    # Weights
    weight = weights.get_weights_col_packed_qkv(
        f"{prefix}.c_attn",
        config.num_attention_heads,
        config.num_attention_heads,
    )

    # Bias
    slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
    shape = slice_.get_shape()
    total_size = shape[0]
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    tensors = []
    for i in range(3):
        tensor = slice_[start + i * single_size : stop + i * single_size]
        tensors.append(tensor)
    bias = torch.cat(tensors, dim=0)
    bias = bias.to(device=weights.device)

    return TensorParallelColumnLinear(get_linear(weight, bias))


def _load_qkv(config, prefix: str, weights, head_size, num_heads):
    """Load QKV from a single, transposed matrix."""

    slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
    shape = slice_.get_shape()
    total_size = shape[1]
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    world_size = weights.process_group.size()
    single_size = total_size // 3
    assert single_size % world_size == 0
    rank = weights.process_group.rank()

    # Weights
    block_size = single_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    tensors = []
    for i in range(3):
        tensor = slice_[:, start + i * single_size : stop + i * single_size]
        tensors.append(tensor)
    weight = torch.cat(tensors, dim=1).T
    weight = weight.to(dtype=weights.dtype)
    weight = weight.to(device=weights.device)

    # Bias
    slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
    shape = slice_.get_shape()
    total_size = shape[0]
    single_size = total_size // 3
    block_size = single_size // world_size
    assert single_size % world_size == 0
    start = rank * block_size
    stop = (rank + 1) * block_size
    b = []
    for i in range(3):
        tensor = slice_[start + i * single_size : stop + i * single_size]
        b.append(tensor)
    bias = torch.cat(b, dim=0)
    bias = bias.to(dtype=weights.dtype)
    bias = bias.to(device=weights.device)
    assert list(bias.shape) == [
        3 * num_heads * head_size
    ], f"{weight.shape} != {[3 * num_heads * head_size]}"

    return TensorParallelColumnLinear(get_linear(weight, bias))


def load_row(config, prefix: str, weights, bias: bool):
    """load_row, but with transposed weight matrices."""

    if config.quantize == "gptq":
        weight = weights.get_weights_row(prefix)
    else:
        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T

    if bias and weights.process_group.rank() == 0:
        # Rank is only on the first rank process
        bias = weights.get_tensor(f"{prefix}.bias")
    else:
        bias = None

    return TensorParallelRowLinear(
        get_linear(weight, bias), process_group=weights.process_group
    )


def load_col(config, prefix: str, weights, bias: bool):
    """load_col, but with transposed weight matrices."""
    if config.quantize == "gptq":
        weight = weights.get_multi_weights_col([prefix], dim=1)
    else:
        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T

    if bias:
        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
    else:
        bias = None

    return TensorParallelColumnLinear(get_linear(weight, bias))


class FlashGPT2Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size

        self.head_size = self.hidden_size // self.num_heads
        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()

        self.query_key_value = load_qkv(
            config,
            prefix=prefix,
            weights=weights,
            head_size=self.head_size,
            num_heads=self.num_heads,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = load_row(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=True,
        )

        self.kv_head_mapping = torch.arange(
            0, self.num_heads, dtype=torch.int32, device=weights.device
        )

    def forward(
        self,
        hidden_states,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        query, key, value = self.query_key_value(hidden_states).split(
            self.head_size * self.num_heads, dim=1
        )
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class GPT2MLP(nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        act = config.activation_function
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )

        self.c_fc = load_col(
            config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
        )
        self.c_proj = load_row(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=True,
        )

        intermediate_size = (
            config.n_inner if config.n_inner is not None else 4 * config.hidden_size
        )

        self.intermediate_size = intermediate_size // weights.process_group.size()

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        return self.c_proj(hidden_states)


class FlashGPT2Layer(nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.self_attn = FlashGPT2Attention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.ln_2",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        attn_output = self.self_attn(
            hidden_states,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        hidden_states = attn_output + residual
        residual = hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)

        mlp_output = self.mlp(hidden_states)

        return residual + mlp_output, residual


class FlashGPT2Model(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.layers = nn.ModuleList(
            [
                FlashGPT2Layer(
                    prefix=(
                        f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
                    ),
                    config=config,
                    weights=weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )

        self.norm = nn.LayerNorm.load(
            prefix="ln_f" if not prefix else f"{prefix}.ln_f",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        hidden_states = self.norm(hidden_states)

        return hidden_states


class FlashGPT2ForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

        self.embed_tokens = TensorParallelEmbedding(
            prefix=("wte" if not prefix else f"{prefix}.wte"),
            weights=weights,
        )
        self.embed_positions = TensorParallelEmbedding(
            prefix=("wpe" if not prefix else f"{prefix}.wpe"),
            weights=weights,
        )

        self.model = FlashGPT2Model(prefix, config, weights)
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="wte" if not prefix else f"{prefix}.wte",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(position_ids)
        inputs_embeds = token_embeds + position_embeds
        hidden_states = self.model(
            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
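The `_load_qkv` helper above shards GPT-2's fused, transposed `c_attn` matrix: its second dimension packs Q, K and V back to back, each rank slices its block out of every third, and the result is transposed into the usual (out_features, in_features) layout. A toy illustration with invented sizes:

import torch

hidden_size, world_size, rank = 8, 2, 0
c_attn = torch.randn(hidden_size, 3 * hidden_size)  # GPT-2 stores the fused QKV transposed

single_size = c_attn.shape[1] // 3
block_size = single_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

# This rank's slice of Q, K and V, re-concatenated and transposed back.
parts = [c_attn[:, start + i * single_size : stop + i * single_size] for i in range(3)]
weight = torch.cat(parts, dim=1).T
print(weight.shape)  # torch.Size([12, 8]) == (3 * hidden_size / world_size, hidden_size)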
@@ -0,0 +1,389 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)
from habana_frameworks.torch.hpex.kernels import (
    RotaryPosEmbeddingMode,
    apply_rotary_pos_emb,
)


def load_attention(config, prefix: str, weights):
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
        weights=weights,
        bias=False,
    )


def load_row(config, prefix: str, weights, bias: bool):
    weight = weights.get_weights_row(prefix)

    if bias and weights.process_group.rank() == 0:
        # Rank is only on the first rank process
        bias = weights.get_tensor(f"{prefix}.bias")
    else:
        bias = None

    linear = get_linear(weight, bias)
    return TensorParallelRowLinear(linear, process_group=weights.process_group)


class GPTJRotary(PositionRotaryEmbedding):
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        num_tokens = query.shape[0]
        head_size = query.shape[-1]
        rope_mode = RotaryPosEmbeddingMode.PAIRWISE
        sin = torch.repeat_interleave(sin, 2, dim=-1)
        cos = torch.repeat_interleave(cos, 2, dim=-1)
        rotary_dim = cos.shape[-1]
        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))

        key_shape = key.shape
        key = key.view(num_tokens, -1, head_size)
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))


class FlashGPTJAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size

        self.head_size = self.hidden_size // self.num_heads
        self.softmax_scale = self.head_size**-0.5
        self.rotary_dim = config.rotary_dim

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()

        self.query_key_value = load_attention(
            config,
            prefix=prefix,
            weights=weights,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = load_row(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=False,
        )

        self.kv_head_mapping = torch.arange(
            0, self.num_heads, dtype=torch.int32, device=weights.device
        )

        self.rotary_emb = GPTJRotary.static(
            config=config,
            dim=self.rotary_dim,
            base=10000,
            device=weights.device,
        )

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        query, key, value = self.query_key_value(hidden_states).split(
            self.head_size * self.num_heads, dim=1
        )
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        # Compute rotary embeddings on rotary_ndims
        if self.rotary_dim is not None:
            self.rotary_emb(
                query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin
            )
        else:
            self.rotary_emb(query, key, cos, sin)

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class GPTJMLP(nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        act = config.activation_function
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )

        self.fc_in = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.fc_in", weights=weights, bias=True
        )

        self.fc_out = load_row(
            config,
            prefix=f"{prefix}.fc_out",
            weights=weights,
            bias=True,
        )

    def forward(self, hidden_states):
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        return self.fc_out(hidden_states)


class FlashGPTJLayer(nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.self_attn = FlashGPTJAttention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastLayerNorm.load(
            prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        # Self Attention
        attn_output = self.self_attn(
            hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        feed_forward_hidden_states = self.mlp(hidden_states)

        return attn_output + feed_forward_hidden_states, residual


class FlashGPTJModel(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.config = config

        self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights)
        self.layers = nn.ModuleList(
            [
                FlashGPTJLayer(
                    prefix=(
                        f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
                    ),
                    config=config,
                    weights=weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )

        self.ln_f = FastLayerNorm.load(
            prefix="ln_f" if not prefix else f"{prefix}.ln_f",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        hidden_states, _ = self.ln_f(hidden_states, residual)

        return hidden_states


class FlashGPTJForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        if not prefix:
            prefix = "transformer"
        else:
            prefix = f"{prefix}.transformer"
        self.model = FlashGPTJModel(prefix, config, weights)
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
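GPTJRotary above rotates only the first `rotary_dim` features of each head, in GPT-J's pairwise (interleaved) layout, and delegates the rotation itself to the HPU `apply_rotary_pos_emb` kernel. A plain-PyTorch sketch of the same math for reference (not the kernel path; shapes are illustrative):

import torch

def pairwise_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    # x: (tokens, heads, head_size); cos/sin: (tokens, 1, rotary_dim // 2)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]  # pairwise (GPT-J) layout
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated.flatten(-2), x_pass), dim=-1)

q = torch.randn(4, 2, 16)
print(torch.allclose(pairwise_rotary(q, torch.ones(4, 1, 4), torch.zeros(4, 1, 4), 8), q))  # True: zero angle is the identity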
@ -0,0 +1,658 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
KVCache,
|
||||||
|
get_kv_scales,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
SpeculativeHead,
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
|
from text_generation_server.layers.layernorm import (
|
||||||
|
FastRMSNorm,
|
||||||
|
FastLayerNorm,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
FastLinear,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.weights import (
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix: str, weights, layer_id):
|
||||||
|
# Only defined in granite.
|
||||||
|
bias = getattr(config, "attention_bias", False)
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
sizes = None
|
||||||
|
prefixes = None
|
||||||
|
|
||||||
|
if config.model_type == "phi3":
|
||||||
|
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=bias,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
|
)
|
||||||
|
prefixes = ["qkv_proj"]
|
||||||
|
elif config.model_type == "baichuan":
|
||||||
|
prefix = f"{prefix}.W_pack"
|
||||||
|
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||||
|
config,
|
||||||
|
prefix=prefix,
|
||||||
|
weights=weights,
|
||||||
|
bias=bias,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
|
)
|
||||||
|
prefixes = [prefix]
|
||||||
|
else:
|
||||||
|
prefixes = ["q_proj", "k_proj", "v_proj"]
|
||||||
|
sizes = [
|
||||||
|
head_size * config.num_attention_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
head_size * config.num_key_value_heads,
|
||||||
|
]
|
||||||
|
base_layer = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TensorParallelMultiAdapterLinear.load(
|
||||||
|
base_layer=base_layer,
|
||||||
|
layer_id=layer_id,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def no_fp8(weights: Weights):
|
||||||
|
"""De-activate fp8 auto conversion for the duration of this context manager"""
|
||||||
|
weights_loader = weights.weights_loader
|
||||||
|
if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
|
||||||
|
weights_loader = HybridFP8UnquantLoader(
|
||||||
|
weights_loader.activation_scale_ub, to_fp8=False
|
||||||
|
)
|
||||||
|
|
||||||
|
with weights.use_loader(weights_loader):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlamaAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
# Setting defaults for baichuan custom config which doesn't apply them.
|
||||||
|
config.rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
config.num_key_value_heads = getattr(
|
||||||
|
config, "num_key_value_heads", config.num_attention_heads
|
||||||
|
)
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# `config.attention_multiplier` is used in Granite
|
||||||
|
self.softmax_scale = getattr(
|
||||||
|
config, "attention_multiplier", self.head_size**-0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
if config.num_key_value_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
|
o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=getattr(config, "attention_bias", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
index,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
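        # Sketch of the grouped-query mapping (hypothetical sharded sizes):
        # with num_heads=8 and num_key_value_heads=2 on this rank,
        # num_groups = 4 and kv_head_mapping == [0, 0, 0, 0, 1, 1, 1, 1],
        # i.e. every block of 4 query heads reads the same KV head.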
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# sdpa
|
||||||
|
attn_output = attention(
|
||||||
|
query=query,
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
query,
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
seqlen,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
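        # Shape sketch for one forward call (hypothetical numbers): with
        # num_tokens=5, num_heads=8, num_key_value_heads=2, head_size=64,
        # qkv has shape (5, (8 + 2 * 2) * 64) = (5, 768); after the split,
        # query is viewed as (5, 8, 64) and kv as (5, 2, 2, 64). Prefill goes
        # through attention(), decode through paged_attention(), and the
        # result is flattened back to (5, 8 * 64) before o_proj.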
|
||||||
|
|
||||||
|
|
||||||
|
class Phi3MoE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# gating
|
||||||
|
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||||
|
|
||||||
|
self.moe = moe_layer_cls(
|
||||||
|
prefix=f"{prefix}.experts",
|
||||||
|
n_experts=config.num_local_experts,
|
||||||
|
n_expert_group=None,
|
||||||
|
renormalize=True,
|
||||||
|
topk=config.num_experts_per_tok,
|
||||||
|
topk_group=None,
|
||||||
|
weights=weights,
|
||||||
|
gate_proj_name="w1",
|
||||||
|
up_proj_name="w3",
|
||||||
|
down_proj_name="w2",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
def forward(self, x, adapter_data) -> torch.Tensor:
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits = self.gate(x)
|
||||||
|
out = self.moe(x, gating_output=router_logits)
|
||||||
|
|
||||||
|
# Reduce sum
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
|
||||||
|
return out.view(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights, index):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_act = config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[self.hidden_act]
|
||||||
|
if "gelu" not in self.hidden_act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh"
|
||||||
|
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
prefixes = None
|
||||||
|
sizes = None
|
||||||
|
|
||||||
|
# Fuse gate and up proj
|
||||||
|
bias = getattr(config, "mlp_bias", False)
|
||||||
|
if config.model_type == "phi3":
|
||||||
|
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefixes = ["gate_proj", "up_proj"]
|
||||||
|
sizes = [
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
]
|
||||||
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
index,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
index,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.intermediate_size = (
|
||||||
|
config.intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: This is a hotfix to be removed & properly refactored.
|
||||||
|
self.quantize = config.quantize
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
def forward(self, hidden_states, adapter_data):
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
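        # A rough numeric sketch of the fused gate/up path (hypothetical
        # sizes): with intermediate_size=11008 per rank, gate_up_proj yields
        # (num_tokens, 2 * 11008); view(-1, 2, 11008) puts gate at [:, 0] and
        # up at [:, 1], and down_proj consumes act(gate) * up, the usual
        # SwiGLU-style MLP when hidden_act == "silu".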
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlamaLayer(nn.Module):
|
||||||
|
def __init__(self, index, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
with no_fp8(weights):
|
||||||
|
self.self_attn = FlashLlamaAttention(
|
||||||
|
index=index,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.model_type == "phimoe":
|
||||||
|
moe_layer_cls = (
|
||||||
|
SparseMoELayer
|
||||||
|
if SparseMoELayer.is_supported(weights)
|
||||||
|
else DenseMoELayer
|
||||||
|
)
|
||||||
|
self.mlp = Phi3MoE(
|
||||||
|
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
|
||||||
|
)
|
||||||
|
# With MoE the layernorms are not RMSNorms and they have a bias
|
||||||
|
self.input_layernorm = FastLayerNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = FastLayerNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mlp = LlamaMLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||||
|
)
|
||||||
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
# This could eventually be baked into the weights like we do for the embeddings/lm_head
|
||||||
|
# but this would mean modifying the lora code
|
||||||
|
self.residual_multiplier = getattr(config, "residual_multiplier", None)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
cross_attention_states,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
if self.residual_multiplier is not None:
|
||||||
|
attn_output *= self.residual_multiplier
|
||||||
|
|
||||||
|
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||||
|
attn_output, res
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||||
|
if self.residual_multiplier is not None:
|
||||||
|
mlp_output *= self.residual_multiplier
|
||||||
|
|
||||||
|
return mlp_output, attn_res
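        # Granite-style scaling sketch (only when residual_multiplier is set,
        # e.g. a hypothetical 0.22): both the attention output and the MLP
        # output are multiplied by it before being folded into the residual
        # stream by the next fused layernorm's residual add.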
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlamaModel(torch.nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
|
||||||
|
# Skip fp8 quant for first and last layers
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
|
||||||
|
with no_fp8(weights):
|
||||||
|
self.layers.append(
|
||||||
|
FlashLlamaLayer(
|
||||||
|
index=0,
|
||||||
|
prefix=f"{prefix}.layers.0",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip first and last layers
|
||||||
|
for layer_id in range(1, config.num_hidden_layers - 1):
|
||||||
|
if layer_id in self.cross_attention_layers:
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mllama import (
|
||||||
|
FlashLlamaCrossLayer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers.append(
|
||||||
|
FlashLlamaCrossLayer(
|
||||||
|
index=layer_id,
|
||||||
|
prefix=(f"{prefix}.layers.{layer_id}"),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.layers.append(
|
||||||
|
FlashLlamaLayer(
|
||||||
|
index=layer_id,
|
||||||
|
prefix=(f"{prefix}.layers.{layer_id}"),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with no_fp8(weights):
|
||||||
|
last_layer_id = config.num_hidden_layers - 1
|
||||||
|
self.layers.append(
|
||||||
|
FlashLlamaLayer(
|
||||||
|
index=last_layer_id,
|
||||||
|
prefix=(f"{prefix}.layers.{last_layer_id}"),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
cross_attention_states=None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid indexing in each layer
|
||||||
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
cross_attention_states,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, name=None):
|
||||||
|
if name is None:
|
||||||
|
name = "model"
|
||||||
|
super().__init__()
|
||||||
|
with no_fp8(weights):
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=(
|
||||||
|
f"{name}.embed_tokens"
|
||||||
|
if not prefix
|
||||||
|
else f"{prefix}.{name}.embed_tokens"
|
||||||
|
),
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.model = FlashLlamaModel(
|
||||||
|
prefix=name if not prefix else f"{prefix}.{name}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
suffix = "model.embed_tokens"
|
||||||
|
else:
|
||||||
|
suffix = "lm_head"
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
embedding_multiplier = getattr(config, "embedding_multiplier", None)
|
||||||
|
if embedding_multiplier is not None:
|
||||||
|
self.embed_tokens.weight.data *= embedding_multiplier
|
||||||
|
prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
|
||||||
|
with no_fp8(weights):
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
config,
|
||||||
|
prefix,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
self.logits_scaling = getattr(config, "logits_scaling", None)
|
||||||
|
if self.logits_scaling is not None and self.lm_head.head is not None:
|
||||||
|
try:
|
||||||
|
# Scale the weights directly
|
||||||
|
self.lm_head.head.linear.weight.data /= self.logits_scaling
|
||||||
|
self.logits_scaled = True
|
||||||
|
except Exception:
|
||||||
|
self.logits_scaled = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_states=None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
hidden_states = self.model(
|
||||||
|
inputs_embeds,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data=adapter_data,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
# Used in Granite
|
||||||
|
if self.logits_scaling is not None and not self.logits_scaled:
|
||||||
|
logits /= self.logits_scaling
|
||||||
|
if speculative_logits is not None:
|
||||||
|
speculative_logits /= self.logits_scaling
|
||||||
|
|
||||||
|
return logits, speculative_logits
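        # Logit scaling sketch (Granite-style, hypothetical value): with
        # logits_scaling=16, the head weights are divided by 16 once at load
        # time when possible (logits_scaled=True); otherwise the division is
        # applied here to logits (and speculative_logits) on every forward.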
|
@@ -0,0 +1,285 @@
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Llava-NeXT model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.image_processing_utils import select_best_resolution
|
||||||
|
|
||||||
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
load_vision_model,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||||
|
"""
|
||||||
|
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size (`tuple`):
|
||||||
|
The size of the input image in the format (height, width).
|
||||||
|
grid_pinpoints (`List`):
|
||||||
|
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||||
|
of the form `(height, width)`.
|
||||||
|
patch_size (`int`):
|
||||||
|
The size of each image patch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: The shape of the image patch grid in the format (height, width).
|
||||||
|
"""
|
||||||
|
if not isinstance(grid_pinpoints, list):
|
||||||
|
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||||
|
|
||||||
|
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||||
|
return height // patch_size, width // patch_size
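    # Worked example (hypothetical inputs): for image_size=(480, 640),
    # grid_pinpoints containing (672, 672), and patch_size=14, if
    # select_best_resolution picks (672, 672) the grid is
    # (672 // 14, 672 // 14) == (48, 48) patches.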
|
||||||
|
|
||||||
|
|
||||||
|
def unpad_image(tensor, original_size):
|
||||||
|
"""
|
||||||
|
Unpads a PyTorch tensor of a padded and resized image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (`torch.Tensor`):
|
||||||
|
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||||
|
original_size (`tuple`):
|
||||||
|
The original size of the image (height, width).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`: The unpadded image tensor.
|
||||||
|
"""
|
||||||
|
original_height, original_width = original_size
|
||||||
|
current_height, current_width = tensor.shape[1:]
|
||||||
|
|
||||||
|
original_aspect_ratio = original_width / original_height
|
||||||
|
current_aspect_ratio = current_width / current_height
|
||||||
|
|
||||||
|
if original_aspect_ratio > current_aspect_ratio:
|
||||||
|
scale_factor = current_width / original_width
|
||||||
|
new_height = int(original_height * scale_factor)
|
||||||
|
padding = (current_height - new_height) // 2
|
||||||
|
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||||
|
else:
|
||||||
|
scale_factor = current_height / original_height
|
||||||
|
new_width = int(original_width * scale_factor)
|
||||||
|
padding = (current_width - new_width) // 2
|
||||||
|
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||||
|
|
||||||
|
return unpadded_tensor
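    # Worked example (hypothetical sizes): for original_size=(300, 600) and a
    # padded tensor of shape (C, 336, 336): original aspect 2.0 > current 1.0,
    # so scale_factor = 336 / 600 = 0.56, new_height = int(300 * 0.56) = 168,
    # padding = (336 - 168) // 2 = 84, and rows 84:252 are kept.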
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||||
|
class LlavaNextMultiModalProjector(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear_1 = TensorParallelColumnLinear.load(
|
||||||
|
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
|
self.linear_2 = TensorParallelRowLinear.load(
|
||||||
|
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image_features):
|
||||||
|
hidden_states = self.linear_1(image_features)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlavaNextForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
config.vision_config.quantize = config.quantize
|
||||||
|
vision_config = config.vision_config
|
||||||
|
# Instead of selecting hidden_states[-2] after the fact,
# only compute the first n - 2 + 1 layers and don't pool.
|
||||||
|
if config.vision_feature_layer < 0:
|
||||||
|
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||||
|
else:
|
||||||
|
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||||
|
self.vision_tower = load_vision_model(
|
||||||
|
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||||
|
config=config.vision_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||||
|
prefix="multi_modal_projector", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
self.image_newline = weights.get_tensor("image_newline")
|
||||||
|
|
||||||
|
self.vocab_size = config.text_config.vocab_size
|
||||||
|
self.config = config
|
||||||
|
config.text_config.quantize = config.quantize
|
||||||
|
config.text_config.speculator = config.speculator
|
||||||
|
self.text_model = load_text_model(
|
||||||
|
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||||
|
config=config.text_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.pad_token_id = (
|
||||||
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
def _merge_input_ids_with_image_features(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
image_features: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||||
|
mask = torch.where(input_ids == self.config.image_token_index)
|
||||||
|
# Let's pray we have enabled enough slots !
|
||||||
|
try:
|
||||||
|
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
|
||||||
|
)
|
||||||
|
return inputs_embeds
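        # Merge sketch (hypothetical token ids): if image_token_index == 32000
        # and input_ids == [1, 32000, 32000, 15043], the mask selects rows 1
        # and 2 of inputs_embeds, which are overwritten in place with the
        # flattened image_features; the number of image tokens must therefore
        # match image_features.view(-1, hidden).shape[0].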
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
# Unused for this model
|
||||||
|
pixel_attention_mask=None,
|
||||||
|
image_sizes: Optional[torch.LongTensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None and len(pixel_values) > 0:
|
||||||
|
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||||
|
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||||
|
# 1. Extract the input embeddings
|
||||||
|
|
||||||
|
# 2. Merge text and images
|
||||||
|
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||||
|
pixel_values = pixel_values.view(
|
||||||
|
num_images * num_patches, channels, height, width
|
||||||
|
)
|
||||||
|
image_features = self.vision_tower(pixel_values)
|
||||||
|
|
||||||
|
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||||
|
# Already done within the clip model
|
||||||
|
selected_image_feature = image_features.last_hidden_state
|
||||||
|
|
||||||
|
if self.config.vision_feature_select_strategy == "default":
|
||||||
|
selected_image_feature = selected_image_feature[:, 1:]
|
||||||
|
elif self.config.vision_feature_select_strategy == "full":
|
||||||
|
selected_image_feature = selected_image_feature
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||||
|
)
|
||||||
|
|
||||||
|
image_features = self.multi_modal_projector(selected_image_feature)
|
||||||
|
|
||||||
|
# split up image_features for each of the individual images
|
||||||
|
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||||
|
# if we assume each image has 5 image features (base image + 4 patches)
|
||||||
|
split_sizes = [num_patches] * num_images
|
||||||
|
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||||
|
|
||||||
|
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||||
|
height = width = (
|
||||||
|
self.config.vision_config.image_size
|
||||||
|
// self.config.vision_config.patch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
new_image_features = []
|
||||||
|
for image_idx, image_feature in enumerate(image_features):
|
||||||
|
if image_feature.shape[0] > 1:
|
||||||
|
base_image_feature = image_feature[0]
|
||||||
|
image_feature = image_feature[1:]
|
||||||
|
|
||||||
|
if height * width != base_image_feature.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
"The number of patches is not consistent with the image size."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dimensions are intentionally swapped to be bug-compatible with
|
||||||
|
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||||
|
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||||
|
image_sizes[image_idx],
|
||||||
|
self.config.image_grid_pinpoints,
|
||||||
|
self.config.vision_config.image_size,
|
||||||
|
)
|
||||||
|
image_feature = image_feature.view(
|
||||||
|
num_patch_height, num_patch_width, height, width, -1
|
||||||
|
)
|
||||||
|
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||||
|
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||||
|
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||||
|
image_feature = torch.cat(
|
||||||
|
(
|
||||||
|
image_feature,
|
||||||
|
self.image_newline[:, None, None].expand(
|
||||||
|
*image_feature.shape[:-1], 1
|
||||||
|
),
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||||
|
image_feature = torch.cat(
|
||||||
|
(base_image_feature, image_feature), dim=0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image_feature = image_feature[0]
|
||||||
|
image_feature = torch.cat(
|
||||||
|
(image_feature, self.image_newline[None]), dim=0
|
||||||
|
)
|
||||||
|
new_image_features.append(image_feature)
|
||||||
|
image_features = torch.stack(new_image_features, dim=0)
|
||||||
|
|
||||||
|
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||||
|
input_ids, inputs_embeds, image_features
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.text_model.model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
adapter_data=adapter_data,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||||
|
return logits, speculative_logits
|
@@ -0,0 +1,481 @@
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
Seqlen,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
SpeculativeHead,
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
|
from text_generation_server.layers.layernorm import (
|
||||||
|
FastRMSNorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralConfig(PretrainedConfig):
|
||||||
|
model_type = "mistral"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096 * 32,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
sliding_window=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralAttention(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
self.max_past = (
|
||||||
|
config.sliding_window if config.sliding_window is not None else -1
|
||||||
|
)
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
if hasattr(config, "head_dim"):
|
||||||
|
self.head_size = config.head_dim
|
||||||
|
else:
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
query_key_value = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||||
|
query_key_value,
|
||||||
|
layer_id,
|
||||||
|
["q_proj", "k_proj", "v_proj"],
|
||||||
|
sizes=[
|
||||||
|
self.head_size * config.num_attention_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
|
o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
layer_id,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# sdpa
|
||||||
|
attn_output = attention(
|
||||||
|
query=query,
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
window_size_left=self.max_past,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
query,
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
seqlen,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
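        # Sliding-window sketch (hypothetical config): with
        # config.sliding_window=4096, max_past=4096 is forwarded as
        # window_size_left in the prefill attention call, so each query only
        # attends to at most the previous 4096 positions; when sliding_window
        # is None, max_past=-1 conventionally disables the window.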
|
||||||
|
|
||||||
|
|
||||||
|
class MistralMLP(nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_act = config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[self.hidden_act]
|
||||||
|
if "gelu" not in self.hidden_act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh"
|
||||||
|
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Fuse gate and up proj
|
||||||
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
layer_id,
|
||||||
|
["gate_proj", "up_proj"],
|
||||||
|
sizes=[
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
layer_id,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
self.intermediate_size = (
|
||||||
|
config.intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: This is a hotfix to be removed & properly refactored.
|
||||||
|
self.quantize = config.quantize
|
||||||
|
|
||||||
|
def forward(self, hidden_states, adapter_data):
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralLayer(nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = MistralAttention(
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
|
self.mlp = MistralMLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# faster post attention rms norm
|
||||||
|
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||||
|
attn_output, res
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||||
|
|
||||||
|
return mlp_output, attn_res
|
||||||
|
|
||||||
|
|
||||||
|
class MistralModel(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MistralLayer(
|
||||||
|
prefix=f"{prefix}.layers.{layer_id}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid indexing in each layer
|
||||||
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
adapter_data,
|
||||||
|
hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMistralForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, name=None):
|
||||||
|
if name is None:
|
||||||
|
name = "model"
|
||||||
|
super().__init__()
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=(
|
||||||
|
f"{name}.embed_tokens"
|
||||||
|
if not prefix
|
||||||
|
else f"{prefix}.{name}.embed_tokens"
|
||||||
|
),
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.model = MistralModel(
|
||||||
|
prefix=name if not prefix else f"{prefix}.{name}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
config,
|
||||||
|
# TODO dirty hack for idefics2.
|
||||||
|
prefix=(
|
||||||
|
"lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
|
||||||
|
),
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.max_past = config.sliding_window
|
||||||
|
self.max_past_tensor = (
|
||||||
|
torch.tensor(config.sliding_window, device=weights.device)
|
||||||
|
if self.max_past is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
hidden_states = self.model(
|
||||||
|
inputs_embeds,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
adapter_data,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
return logits
|
@@ -0,0 +1,515 @@
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from torch import nn
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
FastLinear,
|
||||||
|
SpeculativeHead,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
HPUPagedAttentionMetadata,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralConfig(PretrainedConfig):
|
||||||
|
model_type = "mixtral"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096 * 32,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-05,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
sliding_window=None,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_local_experts=8,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.num_local_experts = num_local_experts
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x.view(1) if len(x.size()) == 0 else x
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix: str, weights):
|
||||||
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
|
return _load_gqa(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
return TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gqa(config, prefix: str, weights):
|
||||||
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||||
|
|
||||||
|
weight = weights.get_multi_weights_col(
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(weight, UnquantizedWeight):
|
||||||
|
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||||
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||||
|
assert list(weight.weight.shape) == [
|
||||||
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||||
|
config.hidden_size,
|
||||||
|
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||||
|
|
||||||
|
return TensorParallelColumnLinear(get_linear(weight, bias=None))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_experts(config, prefix: str, mat, weights):
|
||||||
|
if config.quantize is not None:
|
||||||
|
raise NotImplementedError("Mixtral does not support weight quantization yet.")
|
||||||
|
|
||||||
|
assert mat in ["w1", "w2", "w3"]
|
||||||
|
|
||||||
|
world_size = weights.process_group.size()
|
||||||
|
rank = weights.process_group.rank()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
config.intermediate_size % world_size == 0
|
||||||
|
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||||
|
|
||||||
|
block_size = config.intermediate_size // world_size
|
||||||
|
start = rank * block_size
|
||||||
|
stop = (rank + 1) * block_size
|
||||||
|
|
||||||
|
tensor = torch.empty(
|
||||||
|
(config.num_local_experts * block_size, config.hidden_size),
|
||||||
|
dtype=weights.dtype,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(config.num_local_experts):
|
||||||
|
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||||
|
|
||||||
|
if mat == "w2":
|
||||||
|
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||||
|
else:
|
||||||
|
expert_slice = slice_[start:stop]
|
||||||
|
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
|
||||||
|
dtype=weights.dtype
|
||||||
|
).to(device=weights.device)
|
||||||
|
return tensor
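    # Layout sketch (hypothetical sizes): with num_local_experts=8,
    # intermediate_size=14336 and world_size=2, block_size=7168 and expert i
    # occupies rows [i * 7168 : (i + 1) * 7168] of the stacked tensor, filled
    # with this rank's column shard (w2, transposed) or row shard (w1/w3) of
    # that expert's weight.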
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.max_past = (
|
||||||
|
config.sliding_window if config.sliding_window is not None else -1
|
||||||
|
)
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# sdpa
|
||||||
|
attn_output = attention(
|
||||||
|
query=query,
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
window_size_left=self.max_past,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
query,
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
seqlen,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
|
||||||
|
|
||||||
|
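# TorchScript helpers for MoE routing: `select_experts` picks the top-k experts
# per token from the gate logits and renormalizes their routing weights, and
# `round_up` rounds each element of `x` up to the nearest multiple of `value`.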
@torch.jit.script
def select_experts(gate_logits: torch.Tensor, top_k: int):
    # all_probs: (sequence_length, n_experts) and upcast for softmax
    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
    # weights, selected_experts: (sequence_length, top-k)
    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
    weights /= weights.sum(dim=-1, keepdim=True)
    weights = weights.view(-1)
    selected_experts = selected_experts.view(-1)

    return selected_experts, weights


@torch.jit.script
def round_up(x: torch.Tensor, value: int):
    return torch.div(x + (value - 1), value, rounding_mode="trunc") * value


class MixtralMoE(nn.Module):
    def __init__(
        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
    ):
        super().__init__()

        # gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.moe = moe_layer_cls(
            n_expert_group=None,
            n_experts=config.num_local_experts,
            prefix=f"{prefix}.experts",
            renormalize=True,
            topk=config.num_experts_per_tok,
            topk_group=None,
            weights=weights,
            gate_proj_name="w1",
            up_proj_name="w3",
            down_proj_name="w2",
        )
        assert isinstance(self.moe, MoELayer)

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(x)
        out = self.moe(x, gating_output=router_logits)

        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out.view(*x.shape)


class MixtralLayer(nn.Module):
    def __init__(self, prefix: str, layer_id, config, weights):
        super().__init__()
        prefix = f"{prefix}.layers.{layer_id}"

        self.self_attn = MixtralAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )

        moe_layer_cls = (
            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
        )
        self.moe = MixtralMoE(
            f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
        )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        moe_output = self.moe(normed_attn_res_output)

        return moe_output, attn_res


class MixtralModel(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

        self.embed_tokens = TensorParallelEmbedding(
            prefix=(
                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
            ),
            weights=weights,
        )

        self.layers = nn.ModuleList(
            [
                MixtralLayer(
                    "model" if not prefix else f"{prefix}.model",
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid indexing in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashMixtralForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

        self.model = MixtralModel(prefix, config, weights)
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
            weights=weights,
        )
        self.max_past = config.sliding_window
        self.max_past_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
            if self.max_past is not None
            else None
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits
@@ -0,0 +1,986 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""

from typing import Optional, Tuple, List

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
import torch.nn.functional as F

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    FastLinear,
)
from text_generation_server.layers.attention import (
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA


def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches
    pad_patches = target_length - num_patches
    attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
    attention_mask = attention_mask.reshape(
        batch_size, max_num_tiles * target_length, 1
    )
    attention_mask = (
        attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
    )
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = (
                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

    return causal_mask


def _prepare_cross_attention_mask(
    cross_attention_mask: torch.Tensor,
    num_vision_tokens: int,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # reshape so it can be used by attn module
    batch_size, text_total_length, *_ = cross_attention_mask.shape
    cross_attention_mask = cross_attention_mask.repeat_interleave(
        num_vision_tokens, dim=3
    )
    cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
    cross_attention_mask = cross_attention_mask.unsqueeze(1)

    # invert the mask
    inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
    cross_attention_mask = inverted_cross_attn_mask.masked_fill(
        inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
    )

    # apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row
    # in the cross-attn mask's last dimension contains negative infinity values, otherwise 1
    negative_inf_value = torch.finfo(dtype).min
    full_text_row_masked_out_mask = (
        (cross_attention_mask != negative_inf_value)
        .any(dim=-1)
        .type_as(cross_attention_mask)[..., None]
    )
    cross_attention_mask *= full_text_row_masked_out_mask

    return cross_attention_mask, full_text_row_masked_out_mask


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class MllamaVisionSdpaAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_heads = config.attention_heads // weights.process_group.size()

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_state)
        query, key, value = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
                self.head_dim * self.num_heads,
            ],
            dim=2,
        )

        batch_size, q_seq_len, _ = query.shape
        _, kv_seq_len, _ = key.shape

        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

        output = self.o_proj(attn_output)
        return output


class MllamaVisionEncoderLayer(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
        )

        # there used to be an if else here, no code path
        if is_gated:
            self.gate_attn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
            )
            self.gate_ffn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
            )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state
        return hidden_state


class MllamaVisionEncoder(nn.Module):
    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
        super().__init__()
        self.config = config
        self.layers = [
            MllamaVisionEncoderLayer(
                prefix=f"{prefix}.layers.{i}",
                config=config,
                weights=weights,
                is_gated=is_gated,
            )
            for i in range(num_layers)
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        encoder_states = [hidden_states]
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
            )

            hidden_states = layer_outputs
            encoder_states.append(hidden_states)

        return hidden_states, encoder_states


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        embeddings = self.embedding(aspect_ratio_ids)
        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)

        # Always gated.
        embeddings = embeddings * self.gate.tanh()

        hidden_state = hidden_state + embeddings
        return hidden_state


class MllamaPrecomputedPositionEmbedding(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

        # position embedding
        embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        # position embeddings
        hidden_state = hidden_state + self.gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )

        # precomputed tile position embeddings
        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
        batch_size = hidden_state.shape[0]
        tile_position_embedding = tile_position_embedding.reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
        )
        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
        hidden_state = hidden_state + gated_tile_position_embedding

        return hidden_state


class MllamaVisionModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5
        self.dtype = weights.dtype

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
        )

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding",
            config=config,
            weights=weights,
        )

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding",
            config=config,
            weights=weights,
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding",
            config=config,
            weights=weights,
        )

        ## layer norms
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre",
            weights=weights,
            # torch default
            eps=1e-05,
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post",
            weights=weights,
            # torch default
            eps=1e-05,
        )

        ## encoders
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            config=config,
            weights=weights,
            is_gated=False,
            num_layers=config.num_hidden_layers,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            config=config,
            weights=weights,
            is_gated=True,
            num_layers=config.num_global_layers,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

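    # Forward pipeline: patch-embed each tile, add pre-tile positional
    # embeddings, prepend the class token, add the gated positional
    # embeddings, run the local transformer then the gated global transformer,
    # and concatenate the selected intermediate hidden states on the last dim.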
    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        (
            batch_size,
            num_concurrent_media,
            num_tiles,
            num_channels,
            height,
            width,
        ) = pixel_values.shape

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(pixel_values)
        hidden_state = patch_embeds.flatten(2).transpose(1, 2)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(
                batch_size * num_concurrent_media, -1
            )
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=hidden_state.shape[2],
                dtype=self.dtype,
            )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        intermediate_hidden_states = [
            hidden_state
            for idx, hidden_state in enumerate(all_intermediate_hidden_states)
            if idx in self.intermediate_layers_indices
        ]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state, _ = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.layer_idx = layer_idx

        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=False,
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.softmax_scale = self.head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        # past_key_value=None,
        # attention_mask: Optional[torch.Tensor] = None,
        # cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        # hidden_states = hidden_states.unsqueeze(0)
        # bsz, q_len, _ = hidden_states.size()
        (
            cross_attention_states,
            cu_seqlen_q,
            cu_seqlen_k,
            indices,
        ) = cross_attention_states
        bs = cu_seqlen_q.size(0) - 1
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
        query_states = self.q_norm(query_states)

        key_states = self.k_proj(cross_attention_states)
        value_states = self.v_proj(cross_attention_states)
        key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
        value_states = value_states.view(
            bs, -1, self.num_key_value_heads, self.head_size
        )
        key_states = self.k_norm(key_states)

        # key_states = key_states.repeat(1, self.num_key_value_groups, 1)
        # value_states = value_states.repeat(1, self.num_key_value_groups, 1)

        causal = False
        # logger.info(
        #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
        # )
        # execute sdpa
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        attn_output = fsdpa_op(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=causal,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )
        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
        attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

        return attn_output


# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        shape = x.shape
        gate_up_states = self.gate_up_proj(x)
        gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
        result = self.down_proj(
            self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
        )
        return result


class FlashLlamaCrossLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward."""

    def __init__(self, *, prefix, config, weights, index) -> None:
        layer_idx = index
        super().__init__()
        self.cross_attn = MllamaTextCrossAttention(
            prefix=f"{prefix}.cross_attn",
            config=config,
            weights=weights,
            layer_idx=layer_idx,
        )

        self.input_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.cross_attn_attn_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
        )

        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.post_attention_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(
            weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
        )
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        cross_attention_states,  # [ IB, ...]
        hpu_attention_meta,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if cross_attention_states is None:
            return hidden_states, residual
        if residual is not None:
            hidden_states += residual

        indices = cross_attention_states[-1]
        out_hidden_states = hidden_states[:]
        if len(indices) > 0:
            assert max(indices) < hidden_states.shape[0]
            hidden_states = hidden_states[indices]
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

            hidden_states = self.cross_attn(
                hidden_states=hidden_states,
                # attention_mask=cross_attention_mask,
                cross_attention_states=cross_attention_states,
            )
            hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

            out_hidden_states[indices] = hidden_states
        hidden_states = out_hidden_states

        return hidden_states, None


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
    def __init__(self, weight, eps):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    @classmethod
    def load(cls, *, prefix, weights, eps):
        weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.weight"), requires_grad=False
        )
        return cls(weight=weight, eps=eps)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class FlashMllamaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        config.text_config._attn_implementation = "sdpa"
        self.hidden_size = config.text_config.hidden_size
        self.vision_model = MllamaVisionModel(
            prefix="vision_model", config=config.vision_config, weights=weights
        )
        self.multi_modal_projector = FastLinear.load(
            prefix="multi_modal_projector", config=config, weights=weights, bias=True
        )
        self.text_model = FlashLlamaForCausalLM(
            prefix="language_model", config=config.text_config, weights=weights
        )
        self.config = config
        self.dtype = weights.dtype
        self.device = weights.device

    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
        if aspect_ratio_ids is None:
            raise ValueError(
                "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
            )
        # logger.info(f"Pixel values {pixel_values.shape}")
        batch_size = pixel_values.shape[0]
        vision_states = self.vision_model(
            pixel_values, aspect_ratio_ids, aspect_ratio_mask
        )
        cross_attention_states = self.multi_modal_projector(vision_states).reshape(
            -1, vision_states.shape[-2], self.hidden_size
        )
        _, _, h = cross_attention_states.shape
        cross_attention_states = cross_attention_states.view(batch_size, -1, h)
        # logger.info(f"cross {cross_attention_states.shape}")
        return cross_attention_states

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor],
        adapter_data: Optional[torch.Tensor] = None,
        # XXX: Putting these as optional so that the cuda warmup calls can go through.
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        if cross_attention_states is not None:
            seqlen_q = len(image_indices)
            n_images = cross_attention_states.shape[0]
            seqlen_k = cross_attention_states.shape[1]
            device = cross_attention_states.device
            if cu_seqlen_prefill is not None:
                offset = 0
                cu_q = []
                indices = []
                for index in image_indices:
                    cu_q.append(offset)
                    length = seqlen.input_lengths[index].item()
                    assert index < seqlen.cu_seqlen_q.shape[0]
                    input_ids_offset = seqlen.cu_seqlen_q[index]
                    indices.extend(range(input_ids_offset, input_ids_offset + length))
                    offset += length
                cu_q.append(offset)
                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)

                assert max(indices) < input_ids.shape[0]

                cu_seqlen_k = (
                    torch.arange(
                        n_images + 1,
                        device=device,
                        dtype=torch.int32,
                    )
                    * seqlen_k
                )
            else:
                cu_seqlen_q = torch.arange(
                    seqlen_q + 1, device=device, dtype=torch.int32
                )
                seqlen_k = cross_attention_states.shape[1]
                n_images = cross_attention_states.shape[0]
                cu_seqlen_k = (
                    torch.arange(
                        n_images + 1,
                        device=device,
                        dtype=torch.int32,
                    )
                    * seqlen_k
                )
                indices = image_indices[:]

            cross_attention_states = (
                cross_attention_states,
                cu_seqlen_q,
                cu_seqlen_k,
                indices,
            )

        outputs = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            lm_head_indices=lm_head_indices,
            adapter_data=adapter_data,
            cross_attention_states=cross_attention_states,
        )

        return outputs
@@ -0,0 +1,420 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight


class GPTNeoXConfig(TransformersGPTNeoXConfig):
    attribute_map = {
        "num_key_value_heads": "num_attention_heads",
    }


def load_row(config, prefix: str, weights, bias: bool):
    weight = weights.get_weights_row(prefix)

    if bias and weights.process_group.rank() == 0:
        # The bias is only loaded on the first rank process
        bias = weights.get_tensor(f"{prefix}.bias")
    else:
        bias = None

    linear = get_linear(weight, bias)
    if config.use_parallel_residual:
        return linear
    else:
        return TensorParallelRowLinear(linear, process_group=weights.process_group)


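# GPT-NeoX checkpoints store the fused QKV weight interleaved per head
# (q, k, v for head 0, then head 1, ...). `load_qkv` below reorders it into
# contiguous [all Q | all K | all V] blocks so a single column-parallel linear
# can produce the layout expected by the attention kernels.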
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
||||||
|
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||||
|
if isinstance(weight, UnquantizedWeight):
|
||||||
|
# Only on non quantized versions
|
||||||
|
weight.weight = (
|
||||||
|
weight.weight.view(
|
||||||
|
num_heads,
|
||||||
|
3,
|
||||||
|
head_size,
|
||||||
|
hidden_size,
|
||||||
|
)
|
||||||
|
.permute(1, 0, 2, 3)
|
||||||
|
.reshape(-1, hidden_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||||
|
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
|
||||||
|
|
||||||
|
linear = get_linear(weight, bias)
|
||||||
|
if config.use_parallel_residual:
|
||||||
|
return linear
|
||||||
|
else:
|
||||||
|
return TensorParallelColumnLinear(linear)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashNeoxAttention(torch.nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_size = hidden_size // num_heads
|
||||||
|
|
||||||
|
self.rotary_dim = int(config.rotary_pct * self.head_size)
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.rotary_dim,
|
||||||
|
base=config.rotary_emb_base,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size ** (-0.5)
|
||||||
|
|
||||||
|
self.query_key_value = load_qkv(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.query_key_value",
|
||||||
|
weights=weights,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
head_size=self.head_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
self.dense = load_row(
|
||||||
|
config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_heads, dtype=torch.int32, device=weights.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
|
|
||||||
|
# Compute rotary embeddings on rotary_ndims
|
||||||
|
query_rot = qkv[:, 0][..., : self.rotary_dim]
|
||||||
|
query_pass = qkv[:, 0][..., self.rotary_dim :]
|
||||||
|
key_rot = qkv[:, 1][..., : self.rotary_dim]
|
||||||
|
key_pass = qkv[:, 1][..., self.rotary_dim :]
|
||||||
|
|
||||||
|
# Inplace rotary
|
||||||
|
self.rotary_emb(query_rot, key_rot, cos, sin)
|
||||||
|
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
|
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
|
kv_cache.store(
|
||||||
|
key=qkv[:, 1],
|
||||||
|
value=qkv[:, 2],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# sdpa
|
||||||
|
attn_output = attention(
|
||||||
|
query=qkv[:, 0],
|
||||||
|
key=qkv[:, 1],
|
||||||
|
value=qkv[:, 2],
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
qkv[:, 0],
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
seqlen,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMLP(nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
act = config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dense_h_to_4h = TensorParallelColumnLinear.load(
|
||||||
|
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.dense_4h_to_h = load_row(
|
||||||
|
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense_h_to_4h(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.dense_4h_to_h(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|

class FlashNeoXLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()

        layer_norm_eps = config.layer_norm_eps

        prefix = f"gpt_neox.layers.{layer_id}"

        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = FastLayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
        )
        self.post_attention_layernorm = FastLayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=layer_norm_eps,
        )
        self.attention = FlashNeoxAttention(
            config, prefix=f"{prefix}.attention", weights=weights
        )

        self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
        self.process_group = weights.process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        if self.use_parallel_residual:
            ln1_hidden_states, _ = self.input_layernorm(hidden_states)

            attn_output = self.attention(
                ln1_hidden_states,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache,
                slots,
                seqlen,
                hpu_attention_meta,
            )

            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)

            mlp_output = self.mlp(ln2_hidden_states)
            intermediate = mlp_output + attn_output

            if self.process_group.size() > 1:
                torch.distributed.all_reduce(intermediate, group=self.process_group)

            return intermediate + hidden_states, None
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

            hidden_states = self.attention(
                hidden_states,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache,
                slots,
                seqlen,
                hpu_attention_meta,
            )

            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

            mlp_output = self.mlp(hidden_states)

            return mlp_output, residual
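

# NOTE: illustrative sketch only. FlashNeoXLayer implements GPT-NeoX's two
# residual layouts. With `use_parallel_residual`, the attention and MLP
# branches both read the same input and their outputs are summed onto the
# residual stream; otherwise they are applied one after the other. Ignoring
# the fused layer-norm residual handling above, the dataflow is:
def _parallel_residual_sketch(x, ln1, ln2, attn, mlp):
    # y = x + attn(ln1(x)) + mlp(ln2(x))
    return x + attn(ln1(x)) + mlp(ln2(x))


def _sequential_residual_sketch(x, ln1, ln2, attn, mlp):
    # h = x + attn(ln1(x)); y = h + mlp(ln2(h))
    h = x + attn(ln1(x))
    return h + mlp(ln2(h))
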

class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
    config_class = GPTNeoXConfig
    base_model_prefix = "gpt_neox"
    supports_gradient_checkpointing = False
    _no_split_modules = None


class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
    def __init__(self, prefix: str, config, weights):
        super().__init__(config)
        self.config = config

        self.embed_in = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_in", weights=weights
        )

        self.layers = nn.ModuleList(
            [
                FlashNeoXLayer(layer_id, config, weights)
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.final_layer_norm = FastLayerNorm.load(
            prefix=f"{prefix}.final_layer_norm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].attention.head_size
        self.num_heads = self.layers[0].attention.num_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)

        # Get rotary cos and sin for this forward pass
        # to avoid indexing them in every layer
        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        hidden_states, _ = self.final_layer_norm(hidden_states, residual)

        return hidden_states

class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
    def __init__(self, prefix, config, weights):
        super().__init__(config)

        if not prefix:
            prefix = "gpt_neox"
        else:
            prefix = f"{prefix}.gpt_neox"

        self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)

        self.embed_out = SpeculativeHead.load(
            config, prefix="embed_out", weights=weights
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.gpt_neox(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.embed_out(hidden_states)
        return logits
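

# NOTE: illustrative sketch only. `lm_head_indices` lets the caller run the
# LM head on a subset of positions, typically the last token of each prefill
# sequence, since only those positions need logits for the next-token step.
# Assuming `cu_seqlen_prefill` holds cumulative sequence lengths
# ([0, len_0, len_0 + len_1, ...]), those positions could be derived as:
def _last_token_indices_sketch(cu_seqlen_prefill: torch.Tensor) -> torch.Tensor:
    # The last token of sequence i sits just before boundary i + 1.
    return cu_seqlen_prefill[1:] - 1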
@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed
from torch import nn
from typing import Optional, List, Tuple

from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)


class PaliGemmaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = config.quantize
        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        )
        self.post_vision_tower_layernorm = nn.LayerNorm.load(
            prefix="vision_tower.vision_model.post_layernorm",
            weights=weights,
            eps=config.vision_config.layer_norm_eps,
        )

        self.multi_modal_projector = TensorParallelColumnLinear.load(
            config,
            prefix="multi_modal_projector.linear",
            weights=weights,
            bias=True,
        )

        self.vocab_size = config.text_config.vocab_size
        self.config = config

        text_config = config.text_config
        text_config.speculator = config.speculator
        text_config.quantize = config.quantize
        self.text_model = load_text_model(
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        )
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        # Unused here
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO: This is odd, but PaliGemma position ids apparently start at 1.
        if cu_seqlen_prefill is not None:
            position_ids += 1

        if pixel_values is not None:
            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
            image_outputs = self.vision_tower(pixel_values)
            last_hidden_state = self.post_vision_tower_layernorm(
                image_outputs.last_hidden_state
            )
            image_features = self.multi_modal_projector(last_hidden_state)

            # Mask selecting the positions of image tokens
            mask = input_ids == self.config.image_token_index

            # Insert image features into the input embeddings
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            adapter_data=adapter_data,
        )

        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)

        return logits, speculative_logits
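

# NOTE: illustrative sketch only. The forward above scatters the projected
# vision features into the text embedding sequence wherever the image
# placeholder token appears, assuming one projected feature per placeholder
# token. In isolation, that merge step looks like:
def _merge_image_features_sketch(
    inputs_embeds: torch.Tensor,
    input_ids: torch.Tensor,
    image_features: torch.Tensor,
    image_token_index: int,
) -> torch.Tensor:
    # Boolean mask over the flattened token sequence marking image positions.
    mask = input_ids == image_token_index
    # Flatten the features to [num_image_tokens, hidden_size] and scatter in place.
    inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
    return inputs_embeds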