Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 14:22:08 +00:00)

Compare commits — 481 commits
(file name not shown)

@@ -2,3 +2,6 @@ aml
 target
 server/transformers
 server/flash-attention
+cmake-build-debug/
+cmake-build-release/
+Dockerfile*
.github/workflows/autodocs.yaml (vendored) — 4 lines changed

@@ -28,7 +28,7 @@ jobs:

 - name: Install router
 id: install-router
-run: cargo install --path router/
+run: cargo install --path backends/v3/

 - uses: actions/setup-node@v4
 with:
@@ -41,5 +41,5 @@ jobs:

 - name: Check that documentation is up-to-date
 run: |
-npm install -g swagger-cli
+npm install -g @redocly/cli
 python update_doc.py --check
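The autodocs check now validates the OpenAPI document with Redocly's CLI instead of swagger-cli. A minimal local sketch of the same check, assuming Node.js is available and the spec lives at docs/openapi.json as referenced elsewhere in this repository:

```bash
# Lint the OpenAPI spec the way the updated workflow does (hedged sketch).
npm install -g @redocly/cli
redocly lint docs/openapi.json
# Then verify the generated docs are current using the repo's own script.
python update_doc.py --check
```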
.github/workflows/build.yaml (vendored) — 226 lines changed

@@ -6,10 +6,11 @@ on:
 hardware:
 type: string
 description: Hardware
 # options:
 # - cuda
-# - rocm
-# - intel
+# - cuda-trtllm
+# - rocm
+# - intel
 required: true
 release-tests:
 description: "Run release integration tests"
@@ -21,68 +22,141 @@ jobs:
 build-and-push:
 outputs:
 docker_image: ${{ steps.final.outputs.docker_image }}
+docker_volume: ${{ steps.final.outputs.docker_volume }}
 docker_devices: ${{ steps.final.outputs.docker_devices }}
 runs_on: ${{ steps.final.outputs.runs_on }}
-label: ${{ steps.final.outputs.label }}
+label_extension: ${{ steps.final.outputs.label_extension }}
+extra_pytest: ${{ steps.final.outputs.extra_pytest }}
 concurrency:
 group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
 cancel-in-progress: true
 runs-on:
-group: aws-highmemory-32-plus-priv
+group: aws-highmemory-64-plus-priv
 permissions:
 contents: write
 packages: write
-# This is used to complete the identity challenge
-# with sigstore/fulcio when running outside of PRs.
 id-token: write
-security-events: write
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4
 - name: Inject slug/short variables
 uses: rlespinasse/github-slug-action@v4.4.1
-- name: Construct harware variables
+- name: Inject required variables for sccache to interact with Github Actions Cache
+uses: actions/github-script@v7
+with:
+script: |
+core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
+core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+
+- name: Extract TensorRT-LLM version
+run: |
+echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
+echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
+- name: Construct hardware variables
 shell: bash
 run: |
 case ${{ inputs.hardware }} in
 cuda)
 export dockerfile="Dockerfile"
 export label_extension=""
+export docker_volume="/mnt/cache"
 export docker_devices=""
-export runs_on="aws-g5-12xlarge-plus"
+export runs_on="aws-g6-12xl-plus-priv-cache"
+export platform=""
+export extra_pytest=""
+export target=""
+;;
+cuda-trtllm)
+export dockerfile="Dockerfile_trtllm"
+export label_extension="-trtllm"
+export docker_volume="/mnt/cache"
+export docker_devices=""
+export runs_on="ubuntu-latest"
+export platform=""
+export extra_pytest=""
+if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
+export build_type="release";
+export target="";
+else
+export build_type="dev";
+export target="ci-runtime";
+fi
 ;;
 rocm)
 export dockerfile="Dockerfile_amd"
 export label_extension="-rocm"
 export docker_devices="/dev/kfd,/dev/dri"
-# TODO Re-enable when they pass.
-# export runs_on="amd-gpu-tgi"
+export docker_volume="/mnt"
+# This runner was deactivated.
 export runs_on="ubuntu-latest"
+export platform=""
+export extra_pytest="-k test_flash_gemma_gptq_load"
+export target=""
 ;;
-intel)
+intel-xpu)
 export dockerfile="Dockerfile_intel"
-export label_extension="-intel"
+export label_extension="-intel-xpu"
+export docker_devices=""
+export docker_volume="/mnt/cache"
+export runs_on="ubuntu-latest"
+export platform="xpu"
+export extra_pytest=""
+export target=""
+;;
+intel-cpu)
+export dockerfile="Dockerfile_intel"
+export label_extension="-intel-cpu"
+export docker_devices="none"
+export docker_volume="/mnt/cache"
+# export runs_on="ubuntu-latest"
+export runs_on="aws-highmemory-32-plus-priv"
+export platform="cpu"
+export extra_pytest="-k test_flash_gemma_simple"
+export target=""
+;;
+neuron)
+export dockerfile="Dockerfile.neuron"
+export label_extension="-neuron"
+export docker_devices="/dev/neuron0"
+export docker_volume="/mnt/cache"
+export runs_on="aws-inf2-8xlarge"
+export platform="cpu"
+export extra_pytest="--neuron"
+export target=""
+;;
+gaudi)
+export dockerfile="Dockerfile_gaudi"
+export label_extension="-gaudi"
+export docker_volume="/mnt/cache"
 export docker_devices=""
 export runs_on="ubuntu-latest"
-;;
+export platform=""
+export extra_pytest=""
+export target=""
 esac
 echo $dockerfile
 echo "Dockerfile=${dockerfile}"
 echo $label_extension
 echo $docker_devices
 echo $runs_on
+echo $platform
 echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
-echo "LABEL=${label_extension}" >> $GITHUB_ENV
+echo "LABEL_EXTENSION=${label_extension}" >> $GITHUB_ENV
+echo "PLATFORM=${platform}" >> $GITHUB_ENV
+echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
 echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
 echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
+echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
+echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
+echo "TARGET=${target}" >> $GITHUB_ENV
+echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
 - name: Initialize Docker Buildx
 uses: docker/setup-buildx-action@v3
 with:
 install: true
-buildkitd-config-inline: |
-[registry."docker.io"]
-mirrors = ["registry-us-east-1-mirror.prod.aws.ci.huggingface.tech"]
+buildkitd-config: /tmp/buildkitd.toml
 - name: Login to internal Container Registry
+if: github.event_name != 'pull_request'
 uses: docker/login-action@v3
 with:
 username: ${{ secrets.REGISTRY_USERNAME }}
@@ -95,6 +169,12 @@ jobs:
 registry: ghcr.io
 username: ${{ github.actor }}
 password: ${{ secrets.GITHUB_TOKEN }}
+- name: Login to Docker Hub Container Registry
+uses: docker/login-action@v3
+with:
+registry: docker.io
+username: ${{ secrets.DOCKERHUB_USERNAME }}
+password: ${{ secrets.DOCKERHUB_PASSWORD }}
 - name: Login to Azure Container Registry
 if: github.event_name != 'pull_request'
 uses: docker/login-action@v3
@@ -109,10 +189,9 @@ jobs:
 uses: docker/metadata-action@v5
 with:
 images: |
-registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference
-registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+docker.io/huggingface/text-generation-inference-ci
 tags: |
-type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
 # If main, release or tag
 - name: Extract metadata (tags, labels) for Docker
 if: ${{ github.event_name != 'pull_request' }}
@@ -120,17 +199,16 @@ jobs:
 uses: docker/metadata-action@v4.3.0
 with:
 flavor: |
-latest=auto
+latest=false
 images: |
-registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference
-registry.internal.huggingface.tech/api-inference/community/text-generation-inferenceca
+registry.internal.huggingface.tech/api-inference/community/text-generation-inference
 ghcr.io/huggingface/text-generation-inference
 db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
 tags: |
-type=semver,pattern={{version}}${{ env.LABEL }}
-type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }}
-type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
-type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+type=semver,pattern={{version}}${{ env.LABEL_EXTENSION }}
+type=semver,pattern={{major}}.{{minor}}${{ env.LABEL_EXTENSION }}
+type=raw,value=latest${{ env.LABEL_EXTENSION }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
 - name: Build and push Docker image
 id: build-and-push
 uses: docker/build-push-action@v4
@@ -141,28 +219,41 @@ jobs:
 platforms: 'linux/amd64'
 build-args: |
 GIT_SHA=${{ env.GITHUB_SHA }}
-DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
+PLATFORM=${{ env.PLATFORM }}
+build_type=${{ env.BUILD_TYPE }}
+sccache_gha_enabled=on
+actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
+actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
+target: ${{ env.TARGET }}
 tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
 labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
-cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
 - name: Final
 id: final
 run: |
-echo "docker_image=registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+if [ "${{ github.event_name }}" = "pull_request" ]; then
+echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+else
+echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+fi
 echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
+echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
 echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
-echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
-integration_tests:
+echo "label_extension=${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
+precompile_neuron_models:
 concurrency:
-group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
+group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
 cancel-in-progress: true
 needs: build-and-push
+if: needs.build-and-push.outputs.label_extension == '-neuron'
 runs-on:
 group: ${{ needs.build-and-push.outputs.runs_on }}
-if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
 env:
-PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
+PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4
@@ -171,15 +262,66 @@ jobs:
 - name: Set up Python
 uses: actions/setup-python@v4
 with:
-python-version: "3.10"
+python-version: "3.11"
+- name: Install
+run: |
+make install-integration-tests
+- name: Export neuron models
+run: |
+export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
+echo $DOCKER_IMAGE
+docker pull $DOCKER_IMAGE
+export HF_TOKEN=${{ secrets.HF_TOKEN_NEURON }}
+python integration-tests/fixtures/neuron/export_models.py
+integration_tests:
+concurrency:
+group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
+cancel-in-progress: true
+needs: [precompile_neuron_models, build-and-push]
+if: ${{ always() && !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && needs.build-and-push.outputs.runs_on != 'ubuntu-latest' }}
+runs-on:
+group: ${{ needs.build-and-push.outputs.runs_on }}
+env:
+PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
+steps:
+- name: Checkout repository
+uses: actions/checkout@v4
+- name: Inject slug/short variables
+uses: rlespinasse/github-slug-action@v4.4.1
+- name: Set up Python
+uses: actions/setup-python@v4
+with:
+python-version: "3.11"
 - name: Install
 run: |
 make install-integration-tests
 - name: Run tests
 run: |
-export DOCKER_VOLUME=/mnt/cache
+export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}
 export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
 export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
+export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}"
 export HF_TOKEN=${{ secrets.HF_TOKEN }}
 echo $DOCKER_IMAGE
-pytest -s -vv integration-tests ${PYTEST_FLAGS}
+docker pull $DOCKER_IMAGE
+pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
+
+backend_trtllm_cxx_tests:
+needs: build-and-push
+if: needs.build-and-push.outputs.label_extension == '-trtllm'
+concurrency:
+group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
+cancel-in-progress: true
+runs-on:
+group: aws-g6-12xl-plus-priv-cache
+container:
+image: ${{ needs.build-and-push.outputs.docker_image }}
+credentials:
+username: ${{ secrets.DOCKERHUB_USERNAME }}
+password: ${{ secrets.DOCKERHUB_PASSWORD }}
+options: --gpus all --shm-size=8g
+
+steps:
+- name: Run C++/CUDA tests
+if: ${{ env.LABEL_EXTENSION == 'ci-runtime' }}
+run: /usr/local/tgi/bin/tgi_trtllm_backend_tests
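The hardware case statement above hands its results to later steps through $GITHUB_ENV and to other jobs through $GITHUB_OUTPUT. A hedged sketch of that hand-off in isolation, inside a single Actions step (values are illustrative, not the full matrix):

```bash
#!/usr/bin/env bash
# Map a hardware name to build settings, then publish them the way the workflow does.
hardware="cuda"   # in CI this comes from ${{ inputs.hardware }}
case "$hardware" in
  cuda)
    dockerfile="Dockerfile"
    label_extension=""
    ;;
esac
# Later steps in the same job read these as env.DOCKERFILE / env.LABEL_EXTENSION.
echo "DOCKERFILE=${dockerfile}" >> "$GITHUB_ENV"
echo "LABEL_EXTENSION=${label_extension}" >> "$GITHUB_ENV"
# Other jobs read this via needs.<job>.outputs.* once re-echoed in the "Final" step.
echo "label_extension=${label_extension}" >> "$GITHUB_OUTPUT"
```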
(file name not shown)

@@ -11,7 +11,7 @@ concurrency:

 jobs:
 build:
-uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yaml@main
+uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
 with:
 commit_sha: ${{ github.event.pull_request.head.sha }}
 pr_number: ${{ github.event.number }}
.github/workflows/ci_build.yaml (vendored) — 9 lines changed

@@ -10,6 +10,7 @@ on:
 paths:
 - ".github/workflows/build.yaml"
 - "integration-tests/**"
+- "backends/**"
 - "server/**"
 - "proto/**"
 - "router/**"
@@ -19,6 +20,8 @@ on:
 - "Dockerfile"
 - "Dockerfile_amd"
 - "Dockerfile_intel"
+- "Dockerfile.neuron"
+- "Dockerfile_gaudi"
 branches:
 - "main"
 workflow_dispatch:
@@ -36,8 +39,12 @@ jobs:
 # fail-fast is true by default
 fail-fast: false
 matrix:
-hardware: ["cuda", "rocm", "intel"]
+hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
 uses: ./.github/workflows/build.yaml # calls the one above ^
+permissions:
+contents: write
+packages: write
+id-token: write
 with:
 hardware: ${{ matrix.hardware }}
 # https://github.com/actions/runner/issues/2206
.github/workflows/load_test.yaml (vendored) — 41 lines changed

@@ -3,12 +3,17 @@ name: Nightly load test
 on:
 schedule:
 - cron: '0 0 * * 1-5'
+workflow_call:
+workflow_dispatch:

 pull_request:
 paths:
 - ".github/workflows/load_test.yaml"
-branches:
-- 'main'
+env:
+AWS_DEFAULT_REGION: us-east-1
+AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
+AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}

 jobs:
 load-tests:
@@ -16,28 +21,30 @@ jobs:
 group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
 cancel-in-progress: true
 runs-on:
-group: aws-g5-12xlarge
+group: aws-g6-12xl-plus-priv-cache
 env:
 DOCKER_VOLUME: /cache
 steps:
 - name: Checkout repository
 uses: actions/checkout@v3

-- name: Install k6
-run: |
-curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1
+- name: Install Python 3.11
+uses: actions/setup-python@v2
+with:
+python-version: 3.11

-- name: Start starcoder
+- name: Install poetry
 run: |
-docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
-sleep 10
-wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
+curl -sSL https://install.python-poetry.org | python3 -
+export PATH="$HOME/.local/bin:$PATH"
+poetry --version

-- name: Run k6
+- name: Run bench test
 run: |
-./k6 run load_tests/starcoder_load.js
-- name: Stop starcoder
-if: ${{ always() }}
-run: |
-docker stop tgi-starcoder || true
+export PATH="$HOME/.local/bin:$PATH"
+cd load_tests
+poetry install
+poetry run python benchmarks.py --sha ${{ github.sha }} --results-file "s3://text-generation-inference-ci/benchmarks/ci/${{ github.sha }}.parquet"
+shell: bash
+env:
+HF_TOKEN: ${{ secrets.HF_TOKEN_BENCHMARK }}
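The nightly load test now drives a Poetry-managed Python benchmark instead of a k6 script. A hedged sketch of running it locally, assuming load_tests/benchmarks.py exists as referenced above and a local results path is acceptable:

```bash
curl -sSL https://install.python-poetry.org | python3 -
export PATH="$HOME/.local/bin:$PATH"
cd load_tests
poetry install
# --sha and --results-file mirror the CI invocation; the local parquet path is an assumption.
HF_TOKEN=<your-token> poetry run python benchmarks.py --sha "$(git rev-parse HEAD)" --results-file "./results.parquet"
```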
.github/workflows/nix_build.yaml (vendored, new file) — 53 lines

@@ -0,0 +1,53 @@
+name: "Nix Build Docker image"
+on:
+pull_request:
+push:
+branches:
+- 'main'
+tags:
+- 'v*'
+concurrency:
+group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+cancel-in-progress: true
+
+jobs:
+build_nix_image:
+runs-on:
+group: aws-highmemory-32-plus-priv
+steps:
+- uses: actions/checkout@v4
+- uses: cachix/install-nix-action@v27
+with:
+nix_path: nixpkgs=channel:nixos-unstable
+- uses: cachix/cachix-action@v14
+with:
+name: text-generation-inference
+# If you chose signing key for write access
+authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
+env:
+USER: github_runner
+- name: Build
+run: nix build .#dockerImage
+- name: Initialize Docker Buildx
+uses: docker/setup-buildx-action@v3
+with:
+install: true
+buildkitd-config: /tmp/buildkitd.toml
+- name: Inject slug/short variables
+uses: rlespinasse/github-slug-action@v4.4.1
+- name: Login to internal Container Registry
+# if: github.event_name != 'pull_request'
+uses: docker/login-action@v3
+with:
+username: ${{ secrets.REGISTRY_USERNAME }}
+password: ${{ secrets.REGISTRY_PASSWORD }}
+registry: registry.internal.huggingface.tech
+- name: Push to docker
+run: |
+if [ "${{ github.event_name }}" = "pull_request" ]; then
+export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
+else
+export TAG=${{ github.ref_name }}-nix
+fi
+export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
+nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"
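The Nix image job builds a Docker archive with the flake and ships it with skopeo. An equivalent local sketch, assuming a flakes-enabled Nix install and a registry you are allowed to push to (the destination tag below is a placeholder):

```bash
nix build .#dockerImage
# ./result is a docker-archive tarball; copy it straight into a registry.
nix-shell -p skopeo --command \
  "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://registry.example.com/tgi:nix-test --dest-compress-format zstd"
```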
.github/workflows/nix_cache.yaml (vendored, new file) — 34 lines

@@ -0,0 +1,34 @@
+name: "Cache devshells"
+on:
+pull_request:
+paths:
+- "flake.nix"
+- "flake.lock"
+- "nix/**"
+concurrency:
+group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+cancel-in-progress: true
+
+jobs:
+tests:
+runs-on:
+group: aws-highmemory-32-plus-priv
+steps:
+- uses: actions/checkout@v4
+- uses: cachix/install-nix-action@v27
+with:
+nix_path: nixpkgs=channel:nixos-unstable
+- uses: cachix/cachix-action@v14
+with:
+name: text-generation-inference
+# If you chose signing key for write access
+authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
+env:
+USER: github_runner
+- name: Build impure devshell
+run: nix build .\#devShells.x86_64-linux.impure
+- name: Build impure devshell (CUDA dev)
+run: nix build .\#devShells.x86_64-linux.impureWithCuda
+# Pure shell dependencies are covered by Nix tests.
+# - name: Build pure devshell
+# run: nix build .\#devShells.x86_64-linux.pure
.github/workflows/nix_tests.yaml (vendored, new file) — 42 lines

@@ -0,0 +1,42 @@
+name: "Nix Tests"
+on:
+pull_request:
+paths:
+- ".github/workflows/nix_tests.yaml"
+- "server/**"
+- "proto/**"
+- "router/**"
+- "launcher/**"
+- "backends/**"
+- "Cargo.lock"
+- "rust-toolchain.toml"
+concurrency:
+group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+cancel-in-progress: true
+
+jobs:
+tests:
+runs-on:
+group: aws-highmemory-32-plus-priv
+steps:
+- uses: actions/checkout@v4
+- uses: cachix/install-nix-action@v27
+with:
+nix_path: nixpkgs=channel:nixos-unstable
+- uses: cachix/cachix-action@v14
+with:
+name: text-generation-inference
+# If you chose signing key for write access
+authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
+env:
+USER: github_runner
+- name: Build
+run: nix develop .#test --command echo "Ok"
+- name: Pre-commit tests.
+run: nix develop .#test --command pre-commit run --all-files
+- name: Python tests.
+run: nix develop .#test --command python -m pytest server/tests/
+env:
+HF_TOKEN: ${{ secrets.HF_TOKEN }}
+- name: Rust tests.
+run: nix develop .#test --command cargo test
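All of the Nix test steps run inside the flake's #test dev shell, so the same checks can be reproduced locally with the commands below (assuming the flake exposes that shell as shown above):

```bash
nix develop .#test --command pre-commit run --all-files
nix develop .#test --command python -m pytest server/tests/
nix develop .#test --command cargo test
```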
.github/workflows/tests.yaml (vendored) — 62 lines changed

@@ -8,6 +8,7 @@ on:
 - "proto/**"
 - "router/**"
 - "launcher/**"
+- "backends/**"
 - "Cargo.lock"
 - "rust-toolchain.toml"

@@ -17,26 +18,17 @@ concurrency:

 jobs:
 run_tests:
-runs-on: ubuntu-latest
-env:
-SCCACHE_GHA_ENABLED: "on"
-RUSTC_WRAPPER: /usr/local/bin/sccache
-SCCACHE: 0.3.3
+runs-on:
+group: aws-highmemory-32-plus-priv

 steps:
-- uses: actions/checkout@v2
+- uses: actions/checkout@v4
 - name: Set up Python
-uses: actions/setup-python@v1
+uses: actions/setup-python@v4
+id: python
 with:
-python-version: 3.9
-- name: Install Rust
-uses: actions-rs/toolchain@v1
+python-version: 3.11
+- uses: dtolnay/rust-toolchain@1.85.0
 with:
-# Released on: 02 May, 2024
-# https://releases.rs/docs/1.78.0/
-toolchain: 1.79.0
-override: true
 components: rustfmt, clippy
 - name: Install Protoc
 uses: arduino/setup-protoc@v1
@@ -44,34 +36,22 @@ jobs:
 run: |
 sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
 sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
-- name: Install sccache
-run: |
-curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
-chmod +x /usr/local/bin/sccache
-- name: configure sccache
-uses: actions/github-script@v6
-with:
-script: |
-core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
-core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
-core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
-core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
-- name: cargo registry cache
-uses: actions/cache@v3
-with:
-key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
-restore-keys: |
-cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
-cargo-${{ runner.os }}-
-path: |
-~/.cargo/registry
-~/.cargo/git
 - name: Install
 run: |
+sudo apt update
+sudo apt install python3.11-dev -y
+pip install -U pip uv
+uv venv
+source ./.venv/bin/activate
 make install-cpu
+- name: Download locked kernels
+run: |
+source ./.venv/bin/activate
+kernels download server
 - name: Run server tests
 run: |
-pip install pytest
+source ./.venv/bin/activate
+uv pip install pytest
 export HF_TOKEN=${{ secrets.HF_TOKEN }}
 pytest -s -vv server/tests
 - name: Pre-commit checks
@@ -82,6 +62,6 @@ jobs:
 - name: Run Rust tests
 run: |
 cargo test
-- name: sccache stats
+- name: Run Rust tests with google feature
 run: |
-/usr/local/bin/sccache --show-stats
+cargo test --features google
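The rewritten test job replaces the hand-rolled sccache and cargo-cache steps with a uv-managed virtual environment. A hedged sketch of the equivalent local setup (the make target and the `kernels` CLI come from this repository's own tooling, as used in the steps above):

```bash
pip install -U pip uv
uv venv
source ./.venv/bin/activate
make install-cpu
kernels download server        # fetch the locked kernels, as in the workflow
uv pip install pytest
pytest -s -vv server/tests
cargo test --features google
```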
.github/workflows/trufflehog.yaml (vendored) — 15 lines changed

@@ -10,9 +10,12 @@ jobs:
 trufflehog:
 runs-on: ubuntu-latest
 steps:
 - name: Checkout code
 uses: actions/checkout@v4
 with:
 fetch-depth: 0
 - name: Secret Scanning
-uses: trufflesecurity/trufflehog@main
+uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
+with:
+# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
+extra_args: --results=verified,unknown --exclude-detectors=postgres
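The secret scan is now pinned to a fixed trufflehog commit and passes extra arguments to silence the postgres detector. A hedged sketch of the same scan run against a local checkout with the trufflehog CLI:

```bash
# Flags mirror extra_args above; result and detector filters are trufflehog v3 CLI options.
trufflehog git file://. --results=verified,unknown --exclude-detectors=postgres
```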
.gitignore (vendored) — 17 lines changed

@@ -3,9 +3,14 @@ target
 router/tokenizer.json
 *__pycache__*

+backends/v2/src/client/pb
+backends/v3/src/client/pb
+backends/client/src/v2/pb
+backends/client/src/v3/pb
+
 # ROCm auto-generated files
 *.hip
-server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllamav2
 server/exllama_kernels/exllama_kernels/hip/
 server/exllama_kernels/exllama_kernels/hip_func/
 *_hip.cuh
@@ -14,3 +19,13 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

 data/
 load_tests/*.json
+server/fbgemmm
+
+.direnv/
+.venv/
+
+# Gaudi auto-generated files
+hl-smi_log*.txt
+.graph_dumps
+out
+hqt_output
(file name not shown)

@@ -4,8 +4,9 @@ repos:
 hooks:
 - id: check-yaml
 - id: end-of-file-fixer
+exclude: crate-hashes.json
 - id: trailing-whitespace
-exclude: docs/source/basic_tutorials/launcher.md
+exclude: docs/source/reference/launcher.md
 - repo: https://github.com/psf/black
 rev: 24.2.0
 hooks:
@@ -13,6 +14,11 @@ repos:
 - repo: https://github.com/doublify/pre-commit-rust
 rev: v1.0
 hooks:
-- id: fmt
 - id: cargo-check
+- id: fmt
 - id: clippy
+- repo: https://github.com/astral-sh/ruff-pre-commit
+rev: v0.3.0
+hooks:
+- id: ruff
+args: [--fix, --exit-non-zero-on-fix]
.redocly.lint-ignore.yaml (new file) — 82 lines

@@ -0,0 +1,82 @@
+# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
+# See https://redoc.ly/docs/cli/ for more information.
+docs/openapi.json:
+no-empty-servers:
+- '#/openapi'
+spec:
+- >-
+#/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
+- >-
+#/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
+- '#/components/schemas/GenerateParameters/properties/grammar/nullable'
+- >-
+#/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
+- '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
+- >-
+#/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
+- '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
+- >-
+#/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
+- '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
+- >-
+#/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
+- '#/components/schemas/GenerateResponse/properties/details/nullable'
+- '#/components/schemas/StreamResponse/properties/details/nullable'
+- '#/components/schemas/ChatRequest/properties/response_format/nullable'
+- '#/components/schemas/ChatRequest/properties/stream_options/nullable'
+- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
+- '#/components/schemas/ToolChoice/nullable'
+- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
+- '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
+- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
+no-invalid-media-type-examples:
+- '#/paths/~1/post/responses/422/content/application~1json/example'
+- '#/paths/~1/post/responses/424/content/application~1json/example'
+- '#/paths/~1/post/responses/429/content/application~1json/example'
+- '#/paths/~1/post/responses/500/content/application~1json/example'
+- '#/paths/~1generate/post/responses/422/content/application~1json/example'
+- '#/paths/~1generate/post/responses/424/content/application~1json/example'
+- '#/paths/~1generate/post/responses/429/content/application~1json/example'
+- '#/paths/~1generate/post/responses/500/content/application~1json/example'
+- >-
+#/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
+- >-
+#/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
+- >-
+#/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
+- >-
+#/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
+- '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
+- >-
+#/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
+- >-
+#/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
+- >-
+#/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
+- >-
+#/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
+- >-
+#/paths/~1v1~1completions/post/responses/422/content/application~1json/example
+- >-
+#/paths/~1v1~1completions/post/responses/424/content/application~1json/example
+- >-
+#/paths/~1v1~1completions/post/responses/429/content/application~1json/example
+- >-
+#/paths/~1v1~1completions/post/responses/500/content/application~1json/example
+operation-4xx-response:
+- '#/paths/~1health/get/responses'
+- '#/paths/~1info/get/responses'
+- '#/paths/~1metrics/get/responses'
+no-unused-components:
+- '#/components/schemas/Completion'
+security-defined:
+- '#/paths/~1/post'
+- '#/paths/~1generate/post'
+- '#/paths/~1generate_stream/post'
+- '#/paths/~1health/get'
+- '#/paths/~1info/get'
+- '#/paths/~1metrics/get'
+- '#/paths/~1tokenize/post'
+- '#/paths/~1v1~1chat~1completions/post'
+- '#/paths/~1v1~1completions/post'
+- '#/paths/~1v1~1models/get'
Cargo.lock (generated) — 2982 lines changed (file diff suppressed because it is too large)
Cargo.toml — 31 lines changed

@@ -1,23 +1,40 @@
 [workspace]
 members = [
 "benchmark",
-"router",
-"router/client",
-"router/grpc-metadata",
-"launcher"
+"backends/v2",
+"backends/v3",
+"backends/grpc-metadata",
+"backends/trtllm",
+"backends/llamacpp",
+"launcher",
+"router"
+]
+default-members = [
+"benchmark",
+"backends/v2",
+"backends/v3",
+"backends/grpc-metadata",
+# "backends/trtllm",
+"launcher",
+"router"
 ]
 resolver = "2"

 [workspace.package]
-version = "2.2.0"
+version = "3.2.3-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

 [workspace.dependencies]
 base64 = "0.22.0"
-tokenizers = { version = "0.19.1", features = ["http"] }
-hf-hub = { version = "0.3.1", features = ["tokio"] }
+tokenizers = { version = "0.20.0", features = ["http"] }
+hf-hub = { version = "0.4.2", features = ["tokio"] }
+metrics = { version = "0.23.0" }
+metrics-exporter-prometheus = { version = "0.15.1", features = [] }
+minijinja = { version = "2.2.0", features = ["json"] }
+minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }

 [profile.release]
 incremental = true
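With the workspace reorganised around backends/*, a plain build now covers only the default-members, and the TensorRT-LLM backend stays opt-in. A hedged sketch (the release-opt profile is defined in this workspace; building by manifest path avoids guessing the crate name):

```bash
# Build the default workspace members.
cargo build --profile release-opt
# Build the opt-in TensorRT-LLM backend explicitly (requires its native dependencies).
cargo build --profile release-opt --manifest-path backends/trtllm/Cargo.toml
```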
Dockerfile — 208 lines changed

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -11,11 +11,15 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher

 RUN cargo chef prepare --recipe-path recipe.json

 FROM chef AS builder

+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+python3.11-dev
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
 curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
 unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -28,32 +32,29 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 ARG GIT_SHA
 ARG DOCKER_LABEL

+COPY Cargo.lock Cargo.lock
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen

 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
-FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
+FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
+WORKDIR /usr/src/

 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
-ARG PYTORCH_VERSION=2.3.0
+ARG PYTORCH_VERSION=2.6
+ARG PYTHON_VERSION=3.11

-ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
-ARG CUDA_VERSION=12.1
-ARG MAMBA_VERSION=24.3.0-0
-ARG CUDA_CHANNEL=nvidia
-ARG INSTALL_CHANNEL=pytorch
 # Automatically set by buildx
 ARG TARGETPLATFORM

-ENV PATH /opt/conda/bin:$PATH

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
 build-essential \
 ca-certificates \
@@ -61,31 +62,18 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 curl \
 git && \
 rm -rf /var/lib/apt/lists/*
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
-# Install conda
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
-"linux/arm64") MAMBA_ARCH=aarch64 ;; \
-*) MAMBA_ARCH=x86_64 ;; \
-esac && \
-curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
-RUN chmod +x ~/mambaforge.sh && \
-bash ~/mambaforge.sh -b -p /opt/conda && \
-rm ~/mambaforge.sh
-
-# Install pytorch
-# On arm64 we exit with an error code
-RUN case ${TARGETPLATFORM} in \
-"linux/arm64") exit 1 ;; \
-*) /opt/conda/bin/conda update -y conda && \
-/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
-esac && \
-/opt/conda/bin/conda clean -ya
+ENV PATH="$PATH:/root/.local/bin"
+RUN uv python install ${PYTHON_VERSION}
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"

 # CUDA kernels builder image
 FROM pytorch-install AS kernel-builder

 ARG MAX_JOBS=8
+ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
 ninja-build cmake \
@@ -99,7 +87,7 @@ WORKDIR /usr/src
 COPY server/Makefile-flash-att Makefile

 # Build specific version of flash attention
-RUN make build-flash-attention
+RUN . .venv/bin/activate && make build-flash-attention

 # Build Flash Attention v2 CUDA kernels
 FROM kernel-builder AS flash-att-v2-builder
@@ -109,96 +97,61 @@ WORKDIR /usr/src
 COPY server/Makefile-flash-att-v2 Makefile
|
COPY server/Makefile-flash-att-v2 Makefile
|
||||||
|
|
||||||
# Build specific version of flash attention v2
|
# Build specific version of flash attention v2
|
||||||
RUN make build-flash-attention-v2-cuda
|
RUN . .venv/bin/activate && make build-flash-attention-v2-cuda
|
||||||
|
|
||||||
# Build Transformers exllama kernels
|
# Build Transformers exllama kernels
|
||||||
FROM kernel-builder AS exllama-kernels-builder
|
FROM kernel-builder AS exllama-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/exllama_kernels/ .
|
COPY server/exllama_kernels/ .
|
||||||
|
|
||||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
RUN . .venv/bin/activate && python setup.py build
|
||||||
|
|
||||||
# Build Transformers exllama kernels
|
# Build Transformers exllama kernels
|
||||||
FROM kernel-builder AS exllamav2-kernels-builder
|
FROM kernel-builder AS exllamav2-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/exllamav2_kernels/ .
|
COPY server/Makefile-exllamav2/ Makefile
|
||||||
|
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
RUN . .venv/bin/activate && make build-exllamav2
|
||||||
|
|
||||||
# Build Transformers awq kernels
|
# Build Transformers awq kernels
|
||||||
FROM kernel-builder AS awq-kernels-builder
|
FROM kernel-builder AS awq-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/Makefile-awq Makefile
|
COPY server/Makefile-awq Makefile
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
|
RUN . .venv/bin/activate && make build-awq
|
||||||
|
|
||||||
# Build eetq kernels
|
|
||||||
FROM kernel-builder AS eetq-kernels-builder
|
|
||||||
WORKDIR /usr/src
|
|
||||||
COPY server/Makefile-eetq Makefile
|
|
||||||
# Build specific version of transformers
|
|
||||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
|
|
||||||
|
|
||||||
# Build marlin kernels
|
|
||||||
FROM kernel-builder AS marlin-kernels-builder
|
|
||||||
WORKDIR /usr/src
|
|
||||||
COPY server/marlin/ .
|
|
||||||
# Build specific version of transformers
|
|
||||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
|
||||||
|
|
||||||
# Build Lorax Punica kernels
|
# Build Lorax Punica kernels
|
||||||
FROM kernel-builder AS lorax-punica-builder
|
FROM kernel-builder AS lorax-punica-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/Makefile-lorax-punica Makefile
|
COPY server/Makefile-lorax-punica Makefile
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
|
RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
|
||||||
|
|
||||||
# Build Transformers CUDA kernels
|
# Build Transformers CUDA kernels
|
||||||
FROM kernel-builder AS custom-kernels-builder
|
FROM kernel-builder AS custom-kernels-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/custom_kernels/ .
|
COPY server/custom_kernels/ .
|
||||||
# Build specific version of transformers
|
# Build specific version of transformers
|
||||||
RUN python setup.py build
|
RUN . .venv/bin/activate && python setup.py build
|
||||||
|
|
||||||
# Build FBGEMM CUDA kernels
|
|
||||||
FROM kernel-builder AS fbgemm-builder
|
|
||||||
|
|
||||||
WORKDIR /usr/src
|
|
||||||
|
|
||||||
COPY server/Makefile-fbgemm Makefile
|
|
||||||
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
|
|
||||||
COPY server/fix_torch90a.sh fix_torch90a.sh
|
|
||||||
|
|
||||||
RUN make build-fbgemm
|
|
||||||
|
|
||||||
# Build vllm CUDA kernels
|
|
||||||
FROM kernel-builder AS vllm-builder
|
|
||||||
|
|
||||||
WORKDIR /usr/src
|
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
|
|
||||||
|
|
||||||
COPY server/Makefile-vllm Makefile
|
|
||||||
|
|
||||||
# Build specific version of vllm
|
|
||||||
RUN make build-vllm-cuda
|
|
||||||
|
|
||||||
# Build mamba kernels
|
# Build mamba kernels
|
||||||
FROM kernel-builder AS mamba-builder
|
FROM kernel-builder AS mamba-builder
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
COPY server/Makefile-selective-scan Makefile
|
COPY server/Makefile-selective-scan Makefile
|
||||||
RUN make build-all
|
RUN . .venv/bin/activate && make build-all
|
||||||
|
|
||||||
|
# Build flashinfer
|
||||||
|
FROM kernel-builder AS flashinfer-builder
|
||||||
|
WORKDIR /usr/src
|
||||||
|
COPY server/Makefile-flashinfer Makefile
|
||||||
|
RUN . .venv/bin/activate && make install-flashinfer
|
||||||
|
|
||||||
# Text Generation Inference base image
|
# Text Generation Inference base image
|
||||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
|
FROM nvidia/cuda:12.4.0-base-ubuntu22.04 AS base
|
||||||
|
|
||||||
# Conda env
|
|
||||||
ENV PATH=/opt/conda/bin:$PATH \
|
|
||||||
CONDA_PREFIX=/opt/conda
|
|
||||||
|
|
||||||
# Text Generation Inference base env
|
# Text Generation Inference base env
|
||||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
ENV HF_HOME=/data \
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||||
PORT=80
|
PORT=80
|
||||||
|
|
||||||
@ -212,52 +165,64 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||||||
git \
|
git \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy conda with PyTorch installed
|
# RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
COPY --from=pytorch-install /opt/conda /opt/conda
|
# ENV PATH="$PATH:/root/.local/bin"
|
||||||
|
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
|
||||||
# Copy build artifacts from flash attention builder
|
|
||||||
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
|
|
||||||
# Copy build artifacts from flash attention v2 builder
|
|
||||||
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
|
|
||||||
|
|
||||||
# Copy build artifacts from custom kernels builder
|
|
||||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from exllama kernels builder
|
|
||||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from exllamav2 kernels builder
|
|
||||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from awq kernels builder
|
|
||||||
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from eetq kernels builder
|
|
||||||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from marlin kernels builder
|
|
||||||
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from fbgemm builder
|
|
||||||
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from vllm builder
|
|
||||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
|
||||||
# Copy build artifacts from mamba builder
|
|
||||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
|
||||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
|
||||||
|
|
||||||
# Install flash-attention dependencies
|
# Install flash-attention dependencies
|
||||||
RUN pip install einops --no-cache-dir
|
# RUN pip install einops --no-cache-dir
|
||||||
|
|
||||||
|
# Copy env with PyTorch installed
|
||||||
|
COPY --from=pytorch-install /usr/src/.venv /usr/src/.venv
|
||||||
|
ENV PYTHON_VERSION=3.11
|
||||||
|
RUN uv python install ${PYTHON_VERSION}
|
||||||
|
ENV VIRTUAL_ENV=/usr/src/.venv/
|
||||||
|
ENV PATH="$PATH:/usr/src/.venv/bin/"
|
||||||
|
|
||||||
# Install server
|
# Install server
|
||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY server server
|
COPY server server
|
||||||
COPY server/Makefile server/Makefile
|
COPY server/Makefile server/Makefile
|
||||||
|
ENV HF_KERNELS_CACHE=/kernels
|
||||||
RUN cd server && \
|
RUN cd server && \
|
||||||
make gen-server && \
|
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
|
||||||
pip install -r requirements_cuda.txt && \
|
make gen-server-raw && \
|
||||||
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
|
kernels download .
|
||||||
pip install nvidia-nccl-cu12==2.22.3
|
|
||||||
|
|
||||||
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
|
RUN cd server && \
|
||||||
|
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
|
||||||
|
uv pip install nvidia-nccl-cu12==2.25.1 && \
|
||||||
|
pwd && \
|
||||||
|
text-generation-server --help
|
||||||
|
|
||||||
|
# Copy build artifacts from flash attention builder
|
||||||
|
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
|
||||||
|
# Copy build artifacts from flash attention v2 builder
|
||||||
|
COPY --from=flash-att-v2-builder /usr/src/.venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
|
||||||
|
# Copy build artifacts from custom kernels builder
|
||||||
|
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from exllama kernels builder
|
||||||
|
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from exllamav2 kernels builder
|
||||||
|
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from awq kernels builder
|
||||||
|
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from lorax punica kernels builder
|
||||||
|
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
# Copy build artifacts from mamba builder
|
||||||
|
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
|
||||||
|
COPY --from=flashinfer-builder /usr/src/.venv/lib/python3.11/site-packages/flashinfer/ /usr/src/.venv/lib/python3.11/site-packages/flashinfer/
|
||||||
|
|
||||||
|
|
||||||
|
# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
|
||||||
|
# Required to find libpython within the rust binaries
|
||||||
|
# This is needed because exl2 tries to load flash-attn
|
||||||
|
# And fails with our builds.
|
||||||
|
ENV EXLLAMA_NO_FLASH_ATTN=1
|
||||||
|
|
||||||
# Deps before the binaries
|
# Deps before the binaries
|
||||||
# The binaries change on every build given we burn the SHA into them
|
# The binaries change on every build given we burn the SHA into them
|
||||||
@ -289,5 +254,6 @@ FROM base
|
|||||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||||
RUN chmod +x /tgi-entrypoint.sh
|
RUN chmod +x /tgi-entrypoint.sh
|
||||||
|
|
||||||
|
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/"
|
||||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||||
# CMD ["--json-output"]
|
# CMD ["--json-output"]
|
||||||
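For orientation, the CUDA image described by the Dockerfile diff above is normally built from the repository root; a minimal sketch follows (the image tag is arbitrary, and GIT_SHA / DOCKER_LABEL are the optional build args declared in the file, not values taken from this diff):

    # build the main CUDA image from the repo root
    docker build -t tgi-cuda:dev -f Dockerfile --build-arg GIT_SHA=$(git rev-parse HEAD) .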
167  Dockerfile.neuron  (new file)
@@ -0,0 +1,167 @@
# Fetch and extract the TGI sources
FROM alpine AS tgi
RUN mkdir -p /tgi

# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
# Note: we cannot use the cargo-chef base image as it uses python 3.11
FROM ubuntu:22.04 AS chef

RUN apt-get update -y \
 && apt-get install -y --no-install-recommends \
    curl ca-certificates build-essential \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean

RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo install cargo-chef --locked

WORKDIR /usr/src

FROM chef AS planner
COPY backends/neuron/Cargo.toml Cargo.toml
COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN apt-get update -y \
 && apt-get install -y --no-install-recommends \
    unzip python3-dev libssl-dev pkg-config \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY backends/neuron/Cargo.toml Cargo.toml
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --release

# Python base image
FROM ubuntu:22.04 AS base

RUN apt-get update -y \
 && apt-get install -y --no-install-recommends \
    python3-pip \
    python3-setuptools \
    python-is-python3 \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean
RUN pip3 --no-cache-dir install --upgrade pip

# Python server build image
FROM base AS pyserver

RUN apt-get update -y \
 && apt-get install -y --no-install-recommends \
    make \
    python3-venv \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean

RUN install -d /pyserver
WORKDIR /pyserver
COPY backends/neuron/server server
COPY proto proto
RUN pip3 install -r server/build-requirements.txt
RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package

# Neuron base image (used for deployment)
FROM base AS neuron

# Install system prerequisites
RUN apt-get update -y \
 && apt-get install -y --no-install-recommends \
    gnupg2 \
    wget \
    python3-dev \
    libexpat1 \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean

RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -

# Install neuronx packages
RUN apt-get update -y \
 && apt-get install -y --no-install-recommends \
    aws-neuronx-dkms=2.19.64.0 \
    aws-neuronx-collectives=2.23.135.0-3e70920f2 \
    aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
    aws-neuronx-tools=2.20.204.0 \
    libxml2 \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean

ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"

# Install manually torch CPU version to avoid pulling CUDA
RUN pip3 install \
    torch==2.5.1 \
    torchvision==0.20.1 \
    --index-url https://download.pytorch.org/whl/cpu

RUN pip3 install \
    neuronx-cc==2.16.372.0 \
    torch-neuronx==2.5.1.2.4.0 \
    transformers-neuronx==0.13.322 \
    neuronx-distributed==0.10.1 \
    libneuronxla==2.1.681.0 \
    --extra-index-url=https://pip.repos.neuron.amazonaws.com

# Install HuggingFace packages
RUN pip3 install \
    hf_transfer huggingface_hub

# Install optimum-neuron
COPY --from=optimum-neuron /optimum-neuron optimum-neuron
RUN pip3 install ./optimum-neuron

# TGI base env
ENV HUGGINGFACE_HUB_CACHE=/tmp \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Disable color logs as they are not supported by CloudWatch
ENV LOGURU_COLORIZE=NO
ENV LOG_COLORIZE=0

# Install router
COPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# Install python server
COPY --from=pyserver /pyserver/build/dist dist
RUN pip install dist/text_generation_server*.tar.gz

# Final image
FROM neuron

COPY backends/neuron/tgi_env.py /tgi_env.py
COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
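A minimal build sketch for the new Neuron Dockerfile, assuming it is invoked from the repository root (the image tag is illustrative and not taken from the file):

    # build the AWS Neuron image
    docker build -t tgi-neuron:dev -f Dockerfile.neuron .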
24  Dockerfile.nix  (new file)
@@ -0,0 +1,24 @@
# Build the image and get out the docker file:
#
# docker build -t tgi-nix-builder -f Dockerfile.nix
# docker run --log-driver=none tgi-nix-builder | docker load

FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use text-generation-inference
WORKDIR /root
ADD . .
RUN nix build .
RUN mkdir /tmp/nix-store-closure
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure

FROM ubuntu:24.04

WORKDIR /app

# Copy /nix/store
COPY --from=builder /tmp/nix-store-closure /nix/store
COPY --from=builder /root/result /app
RUN ldconfig
CMD ["ldconfig", "/app/bin/text-generation-launcher"]
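A minimal end-to-end sketch of the flow described in the comments at the top of Dockerfile.nix; the pipe into docker load comes from the file itself, while the trailing build context "." is an assumption added for completeness:

    # build the nix-based builder image, then load the resulting image into the local daemon
    docker build -t tgi-nix-builder -f Dockerfile.nix .
    docker run --log-driver=none tgi-nix-builder | docker load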
326  Dockerfile_amd
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -11,11 +11,14 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
 RUN cargo chef prepare --recipe-path recipe.json

 FROM chef AS builder

+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    python3.11-dev
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -28,170 +31,247 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 ARG GIT_SHA
 ARG DOCKER_LABEL

+COPY Cargo.lock Cargo.lock
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen

-# Text Generation Inference base image for RoCm
+FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
-FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base

+ARG HIPBLASLT_BRANCH="4d40e36"
+ARG HIPBLAS_COMMON_BRANCH="7c1566b"
+ARG LEGACY_HIPBLASLT_OPTION=
+ARG RCCL_BRANCH="648a58d"
+ARG RCCL_REPO="https://github.com/ROCm/rccl"
+ARG TRITON_BRANCH="e5be006"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+ARG PYTORCH_BRANCH="3a585126"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
+ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
+ARG FA_BRANCH="b7d29fb"
+ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+ARG AITER_BRANCH="21d47a9"
+ARG AITER_REPO="https://github.com/ROCm/aiter.git"

+ENV PATH=/opt/rocm/llvm/bin:$PATH
+ENV ROCM_PATH=/opt/rocm
+ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
+ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}

+ARG PYTHON_VERSION=3.11

+RUN mkdir -p /app
+WORKDIR /app
+ENV DEBIAN_FRONTEND=noninteractive

+# Install Python and other dependencies
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
     ca-certificates \
     ccache \
     curl \
     git \
-    make \
+    ninja-build \
-    libssl-dev \
+    cmake \
-    g++ \
+    software-properties-common \
-    # Needed to build VLLM & flash.
+    python3.11-dev \
-    rocthrust-dev \
+    python3.11-venv && \
-    hipsparse-dev \
+    rm -rf /var/lib/apt/lists/*
-    hipblas-dev \
-    hipblaslt-dev \
-    rocblas-dev \
-    hiprand-dev \
-    rocrand-dev \
-    miopen-hip-dev \
-    hipfft-dev \
-    hipcub-dev \
-    hipsolver-dev \
-    rccl-dev \
-    cmake \
-    python3-dev && \
-    rm -rf /var/lib/apt/lists/*

-# Keep in sync with `server/pyproject.toml
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
-ARG MAMBA_VERSION=23.1.0-1
+ENV PATH="$PATH:/root/.local/bin"
-ARG PYTORCH_VERSION='2.3.0'
+RUN uv python install ${PYTHON_VERSION}
-ARG ROCM_VERSION='6.0.2'
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
-ARG PYTHON_VERSION='3.10.10'
+ENV VIRTUAL_ENV=/usr/src/.venv/
-# Automatically set by buildx
+ENV PATH="$PATH:/usr/src/.venv/bin/"
-ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH

-# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
-# Install mamba
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
-    *) MAMBA_ARCH=x86_64 ;; \
-    esac && \
-    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
-RUN chmod +x ~/mambaforge.sh && \
-    bash ~/mambaforge.sh -b -p /opt/conda && \
-    mamba init && \
-    rm ~/mambaforge.sh

-# Install flash-attention, torch dependencies
+FROM base AS build_hipblaslt
-RUN pip install numpy einops ninja --no-cache-dir
+ARG HIPBLASLT_BRANCH
+ARG HIPBLAS_COMMON_BRANCH
+# Set to "--legacy_hipblas_direct" for ROCm<=6.2
+ARG LEGACY_HIPBLASLT_OPTION
+RUN git clone https://github.com/ROCm/hipBLAS-common.git
+RUN . .venv/bin/activate && cd hipBLAS-common \
+    && git checkout ${HIPBLAS_COMMON_BRANCH} \
+    && mkdir build \
+    && cd build \
+    && cmake .. \
+    && make package \
+    && dpkg -i ./*.deb
+RUN git clone https://github.com/ROCm/hipBLASLt
+RUN . .venv/bin/activate && cd hipBLASLt \
+    && git checkout ${HIPBLASLT_BRANCH} \
+    && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
+    && cd build/release \
+    && make package
+RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install

-RUN conda install intel::mkl-static intel::mkl-include
+FROM base AS build_rccl
-RUN pip uninstall -y triton && \
+ARG RCCL_BRANCH
-    git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
+ARG RCCL_REPO
-    cd triton/python && \
+RUN git clone ${RCCL_REPO}
-    pip install .
+RUN . .venv/bin/activate && cd rccl \
+    && git checkout ${RCCL_BRANCH} \
+    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
+RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install

-RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+FROM base AS build_triton
+ARG TRITON_BRANCH
+ARG TRITON_REPO
+RUN git clone ${TRITON_REPO}
+RUN . .venv/bin/activate && cd triton \
+    && git checkout ${TRITON_BRANCH} \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=dist
+RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install

-ARG _GLIBCXX_USE_CXX11_ABI="1"
+FROM base AS build_amdsmi
-ARG CMAKE_PREFIX_PATH="/opt/conda"
+RUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+    && pip wheel . --wheel-dir=dist
-ARG BUILD_CAFFE2="0" \
+RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
-    BUILD_CAFFE2_OPS="0" \
-    USE_CUDA="0" \
-    USE_ROCM="1" \
-    BUILD_TEST="0" \
-    USE_FBGEMM="0" \
-    USE_NNPACK="0" \
-    USE_QNNPACK="0" \
-    USE_XNNPACK="0" \
-    USE_FLASH_ATTENTION="1" \
-    USE_MEM_EFF_ATTENTION="0"

-RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
+FROM base AS build_pytorch
+ARG PYTORCH_BRANCH
+ARG PYTORCH_VISION_BRANCH
+ARG PYTORCH_REPO
+ARG PYTORCH_VISION_REPO
+ARG FA_BRANCH
+ARG FA_REPO
+RUN git clone ${PYTORCH_REPO} pytorch
+RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
+    pip install -r requirements.txt && git submodule update --init --recursive \
+    && python3 tools/amd_build/build_amd.py \
+    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
+    && pip install dist/*.whl
+RUN git clone ${PYTORCH_VISION_REPO} vision
+RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
+    && python3 setup.py bdist_wheel --dist-dir=dist \
+    && pip install dist/*.whl
+RUN git clone ${FA_REPO}
+RUN . .venv/bin/activate && cd flash-attention \
+    && git checkout ${FA_BRANCH} \
+    && git submodule update --init \
+    && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
+RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
+    && cp /app/vision/dist/*.whl /app/install \
+    && cp /app/flash-attention/dist/*.whl /app/install

-# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+FROM base AS final
-ENV HIP_FORCE_DEV_KERNARG=1
+RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
+    dpkg -i /install/*deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
+RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
+    dpkg -i /install/*deb \
+    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
+    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
+RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
+    . .venv/bin/activate && \
+    pip install /install/*.whl
+RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
+    . .venv/bin/activate && \
+    pip install /install/*.whl
+RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
+    . .venv/bin/activate && \
+    pip install /install/*.whl

-# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
+ARG AITER_REPO
-# However, Triton requires a tunning for each prompt length, which is prohibitive.
+ARG AITER_BRANCH
-ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
+RUN git clone --recursive ${AITER_REPO}
+RUN . .venv/bin/activate && cd aiter \
+    && git checkout ${AITER_BRANCH} \
+    && git submodule update --init --recursive \
+    && pip install -r requirements.txt \
+    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter

-FROM base AS kernel-builder
+RUN rm -rf /var/lib/apt/lists/*

+FROM final AS kernel-builder
 # # Build vllm kernels
 FROM kernel-builder AS vllm-builder
-WORKDIR /usr/src

 COPY server/Makefile-vllm Makefile
+RUN . .venv/bin/activate && pip install setuptools_scm

 # Build specific version of vllm
-RUN make build-vllm-rocm
+RUN . .venv/bin/activate && make build-vllm-rocm

-# Build Flash Attention v2 kernels
-FROM kernel-builder AS flash-att-v2-builder
-WORKDIR /usr/src
-
-COPY server/Makefile-flash-att-v2 Makefile
-
-# Build specific version of flash attention v2
-RUN make build-flash-attention-v2-rocm
-
 # Build Transformers CUDA kernels (gpt-neox and bloom)
 FROM kernel-builder AS custom-kernels-builder
-WORKDIR /usr/src
 COPY server/custom_kernels/ .
-RUN python setup.py build
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist

 # Build exllama kernels
 FROM kernel-builder AS exllama-kernels-builder
-WORKDIR /usr/src
 COPY server/exllama_kernels/ .
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
-RUN python setup.py build

 # Build exllama v2 kernels
 FROM kernel-builder AS exllamav2-kernels-builder
-WORKDIR /usr/src
 COPY server/exllamav2_kernels/ .
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist

-RUN python setup.py build
+FROM kernel-builder AS marlin-kernels
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python3 setup.py bdist_wheel --dist-dir=dist

-FROM base AS base-copy
+FROM kernel-builder AS moe-kernels
+ENV MOE_KERNELS_BRANCH=v0.8.2
+ENV VLLM_TARGET_DEVICE=rocm
+RUN . .venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \
+    cd moe-kernels && \
+    git checkout ${MOE_KERNELS_BRANCH} && \
+    python3 setup.py bdist_wheel --dist-dir=dist

+FROM final AS base-copy

 # Text Generation Inference base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
+ENV HF_HOME=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80

-# Copy builds artifacts from vllm builder
+ENV VIRTUAL_ENV=/app/.venv/
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+ENV PATH="$PATH:/app/.venv/bin/"

-# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-
-# Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-
-# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-
-# Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-
 # Install server
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
 RUN cd server && \
-    make gen-server && \
+    uv pip install grpcio-tools mypy-protobuf && \
-    pip install -r requirements_rocm.txt && \
+    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
-    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+    make gen-server-raw
+RUN cd server && \
+    pwd && \
+    text-generation-server --help

+RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
+    uv pip install /install/*.whl
+RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
+    uv pip install /install/*.whl
+RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
+    uv pip install /install/*.whl
+RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
+    uv pip install /install/*.whl
+RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
+    uv pip install /install/*.whl
+RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
+    uv pip install /install/*.whl
+RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
+    uv pip install /install/*.whl

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -211,8 +291,24 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy

+# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+ENV HIP_FORCE_DEV_KERNARG=1

+# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
+# However, Triton requires a tunning for each prompt length, which is prohibitive.
+ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
+ENV ROCM_USE_SKINNY_GEMM=1

 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

 ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
+ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
+# CMD ["--json-output"]
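A minimal build sketch for the ROCm image, assuming the repository root as context; PYTORCH_ROCM_ARCH is the build arg declared in the Dockerfile above (default "gfx90a;gfx942") and the tag is illustrative:

    # build the ROCm image for MI250/MI300-class GPUs
    docker build -t tgi-rocm:dev -f Dockerfile_amd --build-arg PYTORCH_ROCM_ARCH="gfx90a;gfx942" .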
126  Dockerfile_gaudi  (new file)
@@ -0,0 +1,126 @@
# Those arguments are required to build the image
ARG HABANA_VERSION=1.20.0
ARG PYTORCH_VERSION=2.6.0

# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ENV PYO3_PYTHON="/root/.local/bin/python" \
    PYTHON_SYS_EXECUTABLE="/root/.local/bin/python" \
    PYO3_PYTHON_VERSION="3.10"

RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
    && . $HOME/.local/bin/env \
    && uv python install 3.10 --default --preview \
    && test -f /root/.local/bin/python || (echo "Python 3.10 not found at /root/.local/bin/python" && exit 1)

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt

# Text Generation Inference base image
ARG HABANA_VERSION
ARG PYTORCH_VERSION

FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base

ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0

# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10
RUN python3.10 --version || (echo "Python 3.10 is not installed" && exit 1)

# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    libssl-dev \
    ca-certificates \
    make \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install server
COPY proto proto
COPY backends/gaudi/server server
COPY backends/gaudi/server/Makefile server/Makefile
ARG HABANA_VERSION
RUN cd server && \
    make gen-server && \
    pip install --no-deps -r requirements.txt && \
    bash ./dill-0.3.8-patch.sh && \
    pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher


# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base

ENV HF_HUB_ENABLE_HF_TRANSFER 1
ENV HABANA_VISIBLE_DEVICES all
ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE

COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
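A minimal build sketch for the new Gaudi Dockerfile, assuming the repository root as context; HABANA_VERSION and PYTORCH_VERSION are the build args declared at the top of the file, shown here with their defaults, and the tag is illustrative:

    # build the Intel Gaudi (HPU) image
    docker build -t tgi-gaudi:dev -f Dockerfile_gaudi --build-arg HABANA_VERSION=1.20.0 --build-arg PYTORCH_VERSION=2.6.0 .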
137  Dockerfile_intel
@@ -1,6 +1,6 @@
 ARG PLATFORM=xpu

-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -12,11 +12,14 @@ COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
 RUN cargo chef prepare --recipe-path recipe.json

 FROM chef AS builder

+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    python3.11-dev
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -29,20 +32,48 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 ARG GIT_SHA
 ARG DOCKER_LABEL

+COPY Cargo.lock Cargo.lock
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY benchmark benchmark
 COPY router router
+COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen


 # Text Generation Inference base image for Intel

-FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu
+FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu

 USER root

-ARG MAMBA_VERSION=23.1.0-1
-ARG PYTHON_VERSION='3.11.10'
-# Automatically set by buildx
-ARG TARGETPLATFORM
-ENV PATH=/opt/conda/bin:$PATH
-
-# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
-# Install mamba
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
-    *) MAMBA_ARCH=x86_64 ;; \
-    esac && \
-    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
-RUN chmod +x ~/mambaforge.sh && \
-    bash ~/mambaforge.sh -b -p /opt/conda && \
-    rm ~/mambaforge.sh
-
-RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") exit 1 ;; \
-    *) /opt/conda/bin/conda update -y conda && \
-    /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
-    esac && \
-    /opt/conda/bin/conda clean -ya
-
 # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
 RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
     dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
@@ -52,40 +83,41 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
 | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

-RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
+RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /tmp/intel-for-pytorch-gpu-dev.list

+RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d

+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200

 # Text Generation Inference base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
+ENV HF_HOME=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80


 WORKDIR /usr/src
-RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
-RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed

 # Install server
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_intel.txt && \
+    pip install -U pip uv && \
-    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir

-ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
-ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
-ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
-ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
-ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
-ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
-ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV TORCH_LLM_ALLREDUCE=1
-ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
-RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch

+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
@@ -101,20 +133,31 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     ca-certificates \
     make \
-    g++ \
+    g++-12 \
+    gcc-12 \
     git \
     wget \
     cmake
|
cmake \
|
||||||
|
libnuma-dev
|
||||||
|
|
||||||
|
RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
|
||||||
|
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
|
||||||
|
RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
|
||||||
|
RUN update-alternatives --set cc /usr/bin/gcc
|
||||||
|
|
||||||
|
RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
|
||||||
|
RUN update-alternatives --set c++ /usr/bin/g++
|
||||||
|
|
||||||
|
|
||||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||||
PORT=80
|
PORT=80
|
||||||
|
|
||||||
ARG MAMBA_VERSION=23.1.0-1
|
ARG MAMBA_VERSION=23.1.0-1
|
||||||
ARG PYTHON_VERSION='3.10.10'
|
ARG PYTHON_VERSION='3.11.10'
|
||||||
# Automatically set by buildx
|
# Automatically set by buildx
|
||||||
ARG TARGETPLATFORM
|
ARG TARGETPLATFORM
|
||||||
ENV PATH /opt/conda/bin:$PATH
|
ENV PATH=/opt/conda/bin:$PATH
|
||||||
|
|
||||||
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||||
# Install mamba
|
# Install mamba
|
||||||
@ -128,42 +171,40 @@ RUN chmod +x ~/mambaforge.sh && \
|
|||||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||||
rm ~/mambaforge.sh
|
rm ~/mambaforge.sh
|
||||||
|
|
||||||
|
RUN case ${TARGETPLATFORM} in \
|
||||||
|
"linux/arm64") exit 1 ;; \
|
||||||
|
*) /opt/conda/bin/conda update -y conda && \
|
||||||
|
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||||
|
esac && \
|
||||||
|
/opt/conda/bin/conda clean -ya
|
||||||
|
|
||||||
RUN conda install -c conda-forge gperftools mkl
|
RUN conda install -c conda-forge gperftools mkl
|
||||||
|
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
RUN pip install triton==3.1.0 py-libnuma
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
|
||||||
RUN pip install triton
|
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
|
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||||
|
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||||
|
|
||||||
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
|
|
||||||
|
|
||||||
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
|
||||||
|
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
|
||||||
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
|
ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
|
||||||
|
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
|
||||||
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
|
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
|
||||||
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
||||||
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
|
||||||
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
|
|
||||||
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
|
|
||||||
ENV KMP_BLOCKTIME=1
|
|
||||||
ENV KMP_TPAUSE=0
|
|
||||||
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
|
|
||||||
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
|
|
||||||
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
|
|
||||||
|
|
||||||
# Install server
|
# Install server
|
||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY server server
|
COPY server server
|
||||||
COPY server/Makefile server/Makefile
|
COPY server/Makefile server/Makefile
|
||||||
|
ENV UV_SYSTEM_PYTHON=1
|
||||||
RUN cd server && \
|
RUN cd server && \
|
||||||
make gen-server && \
|
make gen-server && \
|
||||||
pip install -r requirements_intel.txt && \
|
pip install -U pip uv && \
|
||||||
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
|
||||||
|
|
||||||
# Install benchmarker
|
# Install benchmarker
|
||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||||
@ -173,5 +214,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
|
|||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
|
|
||||||
FROM ${PLATFORM} AS final
|
FROM ${PLATFORM} AS final
|
||||||
|
ENV ATTENTION=flashdecoding-ipex
|
||||||
|
ENV PREFIX_CACHING=1
|
||||||
|
ENV PREFILL_CHUNKING=1
|
||||||
|
ENV CUDA_GRAPHS=0
|
||||||
ENTRYPOINT ["text-generation-launcher"]
|
ENTRYPOINT ["text-generation-launcher"]
|
||||||
CMD ["--json-output"]
|
CMD ["--json-output"]
|
||||||
|
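The two platform build stages above converge on `FROM ${PLATFORM} AS final`, so the image flavour is selected at build time through the `PLATFORM` build argument. A minimal build sketch, assuming this is the repository's `Dockerfile_intel` and that the value passed as `PLATFORM` (for example an XPU or a CPU stage) matches the stage names defined earlier in the file; the image tag is illustrative only:

```bash
# Hedged example: stage name and tag are assumptions, not part of the diff above.
docker build -f Dockerfile_intel \
    --build-arg PLATFORM=xpu \
    -t tgi-intel-xpu .
```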
Dockerfile_llamacpp (new file, 88 lines)
@@ -0,0 +1,88 @@
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps

ARG llamacpp_version=b4827
ARG llamacpp_cuda=OFF
ARG llamacpp_native=ON
ARG llamacpp_cpu_arm_arch=native
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real

WORKDIR /opt/src

ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
    clang \
    cmake \
    curl \
    git \
    python3-dev \
    libssl-dev \
    pkg-config \
    tar

ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN mkdir -p llama.cpp \
    && tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
    && cd llama.cpp \
    && cmake -B build \
        -DCMAKE_INSTALL_PREFIX=/usr \
        -DCMAKE_INSTALL_LIBDIR=/usr/lib \
        -DCMAKE_C_COMPILER=clang \
        -DCMAKE_CXX_COMPILER=clang++ \
        -DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
        -DGGML_CUDA=${llamacpp_cuda} \
        -DGGML_NATIVE=${llamacpp_native} \
        -DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
        -DLLAMA_BUILD_COMMON=OFF \
        -DLLAMA_BUILD_TESTS=OFF \
        -DLLAMA_BUILD_EXAMPLES=OFF \
        -DLLAMA_BUILD_SERVER=OFF \
    && cmake --build build --parallel --config Release \
    && cmake --install build

WORKDIR /app
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef --locked

FROM deps AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json

FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook \
    --recipe-path recipe.json \
    --profile release \
    --package text-generation-router-llamacpp
COPY . .
RUN cargo build \
    --profile release \
    --package text-generation-router-llamacpp --frozen

FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
WORKDIR /app

ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
    python3-venv \
    python3-pip

RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"

COPY backends/llamacpp/requirements.txt requirements.txt
COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/

RUN pip3 install --no-cache-dir \
    -r requirements.txt \
    -e gguf-py

COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/

ENV HF_HUB_ENABLE_HF_TRANSFER=1

ENTRYPOINT ["text-generation-router-llamacpp"]
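The build arguments declared at the top of this Dockerfile control which llama.cpp release is vendored and whether CUDA kernels are compiled in. A minimal sketch of a build invocation, overriding those defaults; the chosen architecture value and image tag are illustrative:

```bash
# Hedged example: enable CUDA and narrow the GGML CUDA architecture list.
docker build -f Dockerfile_llamacpp \
    --build-arg llamacpp_cuda=ON \
    --build-arg cuda_arch="86-real" \
    -t tgi-llamacpp .
```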
Dockerfile_trtllm (new file, 158 lines)
@@ -0,0 +1,158 @@
ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real;100-real;120-real"
ARG cuda_base=12.8.0
ARG build_type=release
ARG ompi_version=4.1.7
ARG sccache_gha_enabled=off
ARG actions_results_url=""
ARG actions_runtime_token=""

# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
    build-essential \
    cmake \
    curl \
    gcc-14 \
    g++-14 \
    git \
    git-lfs \
    lld \
    libssl-dev \
    libucx-dev \
    libasan8 \
    libubsan1 \
    ninja-build \
    pkg-config \
    pipx \
    python3 \
    python3-dev \
    python3-setuptools \
    tar \
    wget --no-install-recommends && \
    pipx ensurepath

ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt

# Install OpenMPI
FROM cuda-builder AS mpi-builder
WORKDIR /opt/src/mpi

ARG ompi_version
ENV OMPI_VERSION=${ompi_version}
ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
    https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .

RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
    make -j all && \
    make install && \
    rm -rf ${OMPI_TARBALL_FILENAME}/..

# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
    /opt/install_tensorrt.sh

# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference

# Scoped global args reuse
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_results_url
ARG actions_runtime_token

# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
    chmod -R a+w /root/.rustup && \
    chmod -R a+w /root/.cargo && \
    cargo install sccache --version ">=0.10.0" --locked

ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"

ENV USE_LLD_LINKER=ON
ENV CUDA_ARCH_LIST=${cuda_arch_list}

# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
ENV ACTIONS_RESULTS_URL=${actions_results_url}
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY router router
COPY backends backends
COPY benchmark benchmark
COPY launcher launcher
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi

ENV RUSTC_WRAPPER=sccache
ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
RUN export CC=gcc-14 \
    export CXX=g++-14 \
    export CMAKE_C_COMPILER_LAUNCHER=sccache && \
    export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
    export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
    mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
    cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
    sccache --show-stats

FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime
RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
    rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
    pipx ensurepath && \
    pipx install --include-deps transformers tokenizers

WORKDIR /usr/local/tgi/bin

ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""

COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher

# This is used only for the CI/CD
FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime
RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
    rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
    pipx ensurepath && \
    pipx install --include-deps transformers tokenizers

WORKDIR /usr/local/tgi/bin

ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""

COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi

# Basically we copy from target/debug instead of target/release
COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher

# This is the final image
FROM runtime

LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co"
LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"

ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
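The TensorRT-LLM image is parameterised by the global build arguments declared at the top of the file (CUDA base image, target architectures, build profile, sccache wiring). A minimal local build sketch with illustrative values and image tag:

```bash
# Hedged example: narrow the CUDA architecture list and keep the release profile.
docker build -f Dockerfile_trtllm \
    --build-arg cuda_arch_list="90-real" \
    --build-arg build_type=release \
    -t tgi-trtllm .
```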
Makefile
@@ -5,13 +5,13 @@ install-server-cpu:
 	cd server && make install-server

 install-router:
-	cd router && cargo install --path .
+	cargo install --path backends/v3/

 install-launcher:
-	cd launcher && cargo install --path .
+	cargo install --path launcher/

 install-benchmark:
-	cd benchmark && cargo install --path .
+	cargo install --path benchmark/

 install: install-server install-router install-launcher

@@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:
 clean:
 	rm -rf target aml
+
+preview_doc:
+	doc-builder preview text-generation-inference docs/source --not_python_module
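After this change the Rust binaries are installed from the repository root rather than from inside each crate directory; `make install` still chains the server, router and launcher targets. For reference, a short sketch of the equivalent manual commands, taken from the targets above:

```bash
# Hedged example: what the updated targets run under the hood.
cargo install --path backends/v3/   # install-router
cargo install --path launcher/      # install-launcher
cargo install --path benchmark/     # install-benchmark
```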
README.md
@@ -1,7 +1,7 @@
 <div align="center">

 <a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">
-  <img width=560 width=315 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png">
+  <img width=560 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png">
 </a>

 # Text Generation Inference
@@ -13,7 +13,7 @@
   <img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
 </a>

-A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
 to power Hugging Chat, the Inference API and Inference Endpoint.

 </div>
@@ -28,6 +28,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
   - [Distributed Tracing](#distributed-tracing)
   - [Architecture](#architecture)
   - [Local install](#local-install)
+  - [Local install (Nix)](#local-install-nix)
   - [Optimized architectures](#optimized-architectures)
   - [Run locally](#run-locally)
     - [Run](#run)
@@ -42,12 +43,15 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
 - Tensor Parallelism for faster inference on multiple GPUs
 - Token streaming using Server-Sent Events (SSE)
 - Continuous batching of incoming requests for increased total throughput
+- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API
 - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
 - Quantization with :
   - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
   - [GPT-Q](https://arxiv.org/abs/2210.17323)
   - [EETQ](https://github.com/NetEase-FuXi/EETQ)
   - [AWQ](https://github.com/casper-hansen/AutoAWQ)
+  - [Marlin](https://github.com/IST-DASLab/marlin)
+  - [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@@ -80,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data

 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model
+    ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
 ```

 And then you can make requests like
@@ -92,9 +96,32 @@ curl 127.0.0.1:8080/generate_stream \
     -H 'Content-Type: application/json'
 ```

+You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
+
+```bash
+curl localhost:8080/v1/chat/completions \
+    -X POST \
+    -d '{
+  "model": "tgi",
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "What is deep learning?"
+    }
+  ],
+  "stream": true,
+  "max_tokens": 20
+}' \
+    -H 'Content-Type: application/json'
+```
+
 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.

-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.3-rocm --model-id $model` instead of the command above.

 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
@@ -114,23 +141,24 @@ You have the option to utilize the `HF_TOKEN` environment variable for configuri
 For example, if you want to serve the gated Llama V2 model variants:

 1. Go to https://huggingface.co/settings/tokens
-2. Copy your cli READ token
-3. Export `HF_TOKEN=<your cli READ token>`
+2. Copy your CLI READ token
+3. Export `HF_TOKEN=<your CLI READ token>`

 or with Docker:

 ```shell
-model=meta-llama/Llama-2-7b-chat-hf
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>

-docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
 ```

 ### A note on Shared Memory (shm)

 [`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
-`PyTorch` to do distributed training/inference. `text-generation-inference` make
+`PyTorch` to do distributed training/inference. `text-generation-inference` makes
 use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.

 In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
@@ -163,18 +191,32 @@ overridden with the `--otlp-service-name` argument

 

+Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
+
 ### Local install

 You can also opt to install `text-generation-inference` locally.

-First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
-Python 3.9, e.g. using `conda`:
+First clone the repository and change directory into it:
+
+```shell
+git clone https://github.com/huggingface/text-generation-inference
+cd text-generation-inference
+```
+
+Then [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
+Python 3.9, e.g. using `conda` or `python venv`:

 ```shell
 curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

+#using conda
 conda create -n text-generation-inference python=3.11
 conda activate text-generation-inference
+
+#using python venv
+python3 -m venv .venv
+source .venv/bin/activate
 ```

 You may also need to install Protoc.
@@ -208,6 +250,45 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
 sudo apt-get install libssl-dev gcc -y
 ```

+### Local install (Nix)
+
+Another option is to install `text-generation-inference` locally using [Nix](https://nixos.org). Currently,
+we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
+be pulled from a binary cache, removing the need to build them locally.
+
+First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference).
+Setting up the cache is important, otherwise Nix will build many of the dependencies
+locally, which can take hours.
+
+After that you can run TGI with `nix run`:
+
+```shell
+cd text-generation-inference
+nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
+```
+
+**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)
+to make the CUDA driver libraries visible to Nix packages.
+
+For TGI development, you can use the `impure` dev shell:
+
+```shell
+nix develop .#impure
+
+# Only needed the first time the devshell is started or after updating the protobuf.
+(
+cd server
+mkdir text_generation_server/pb || true
+python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
+    --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
+find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+touch text_generation_server/pb/__init__.py
+)
+```
+
+All development dependencies (cargo, Python, Torch), etc. are available in this
+dev shell.
+
 ## Optimized architectures

 TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
@@ -232,7 +313,7 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2

 ### Quantization

-You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
+You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement:

 ```shell
 text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
@@ -240,6 +321,8 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantiz

 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.

+Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).
+
 ## Develop

 ```shell

assets/v3_benchmarks.png (new binary file, 209 KiB; not shown)
@@ -107,20 +107,22 @@ impl Client {
     #[instrument(skip_all)]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_tokens: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
         // Create requests
         while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let mut truncate = max_prefill_tokens - n_tokens;
+            if let Some(max_input_tokens) = max_input_tokens {
+                truncate = min(max_input_tokens, truncate);
+            }

             let mut input_chunks = Vec::new();
-            input_chunks
-                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
+            input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
             if n_tokens == 0 {
                 input_chunks.push(
                     Chunk::Image(Image {
@@ -136,7 +138,7 @@ impl Client {
             // been updated to support chunks.

             let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
             if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
@@ -145,6 +147,12 @@ impl Client {
                 ));
             }

+            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
+                max_total_tokens - truncate
+            } else {
+                1
+            };
+
             requests.push(Request {
                 id: 0,
                 inputs,
@@ -153,9 +161,13 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Most request will have that
+                add_special_tokens: true,
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
+                cache_len: 0,
+                chunk_len: None,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -171,7 +183,7 @@ impl Client {
                     grammar_type: GrammarType::None as i32,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
@@ -179,7 +191,7 @@ impl Client {
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            n_tokens += truncate;

             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
@@ -191,19 +203,23 @@ impl Client {
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: max_input_length,
+            max_tokens: max_input_tokens.unwrap_or(0),
             max_blocks: 0,
         };

         let request = tonic::Request::new(WarmupRequest {
             batch: Some(batch),
-            max_input_length,
+            max_input_tokens,
             max_prefill_tokens,
             max_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+        Ok((
+            response.max_supported_total_tokens,
+            response.max_input_tokens,
+            response.max_total_tokens,
+        ))
     }

     /// Generate one token for each request in the given batch
@@ -214,8 +230,13 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
backends/client/src/v3/sharded_client.rs (new file, 271 lines)
@@ -0,0 +1,271 @@
/// Multi shard Client
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};

use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};

#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
    clients: Vec<Client>,
}

impl ShardedClient {
    fn new(clients: Vec<Client>) -> Self {
        Self { clients }
    }

    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
        // Get all uris/unix sockets from the master client
        let uris = master_client.service_discovery().await?;
        let futures = uris.into_iter().map(Client::connect_uds);
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
    }

    /// Returns a client connected to the given uri
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
        Self::from_master_client(master_client).await
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
        Self::from_master_client(master_client).await
    }

    /// Get the model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<ShardInfo> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
    }

    /// GRPC health check
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.health())
            .collect();
        join_all(futures).await.pop().unwrap()
    }

    /// Clear the past generations cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.clear_cache(batch_id))
            .collect();
        join_all(futures).await.into_iter().collect()
    }

    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip(self))]
    pub async fn warmup(
        &mut self,
        max_input_length: Option<u32>,
        max_prefill_tokens: u32,
        max_total_tokens: Option<u32>,
        max_batch_size: Option<usize>,
    ) -> Result<(Option<u32>, u32, u32)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| {
                Box::pin(client.warmup(
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                    max_batch_size,
                ))
            })
            .collect();
        // Take the minimum value
        let results = join_all(futures)
            .await
            .into_iter()
            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;

        // Take the minimum value
        // Different shards hold different parts of vocab, might yield
        // different available block size.
        let min = results
            .iter()
            .min()
            .expect("Expect at least 1 warmup result");
        Ok(*min)
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
        cached_batch: Option<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
            .collect();
        #[allow(clippy::type_complexity)]
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
            join_all(futures).await.into_iter().collect();
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.decode(batches.clone())))
            .collect();
        #[allow(clippy::type_complexity)]
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
            join_all(futures).await.into_iter().collect();
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
    }
}

impl From<InfoResponse> for ShardInfo {
    fn from(value: InfoResponse) -> Self {
        Self {
            requires_padding: value.requires_padding,
            dtype: value.dtype,
            device_type: value.device_type,
            window_size: value.window_size,
            speculate: value.speculate,
        }
    }
}

#[async_trait]
impl Health for ShardedClient {
    async fn device_health(&self) -> Result<()> {
        self.clone().health().await?;
        Ok(())
    }

    async fn model_health(&self) -> Result<()> {
        // Dummy batch of 1 token and 1 generated token
        let liveness_request = Request {
            id: u64::MAX,
            inputs: "liveness".to_string(),
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text("liveness".into()).into()],
            }),
            truncate: 10,
            add_special_tokens: true,
            prefill_logprobs: false,
            parameters: Some(NextTokenChooserParameters {
                temperature: 1.0,
                top_k: 0,
                top_p: 1.0,
                typical_p: 1.0,
                do_sample: false,
                seed: 0,
                repetition_penalty: 1.0,
                frequency_penalty: 0.0,
                watermark: false,
                grammar: String::new(),
                grammar_type: GrammarType::None as i32,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: 1,
                stop_sequences: vec![],
                ignore_eos_token: false,
            }),
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
            cache_len: 0,
            chunk_len: None,
            adapter_id: None,
        };
        let batch = Batch {
            id: u64::MAX,
            requests: vec![liveness_request],
            size: 1,
            max_tokens: 2,
            max_blocks: 1,
        };
        self.clone().prefill(batch, None).await?;
        Ok(())
    }
}
62
backends/gaudi/Makefile
Normal file
62
backends/gaudi/Makefile
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
|
||||||
|
mkfile_dir := $(dir $(mkfile_path))
|
||||||
|
root_dir := ${mkfile_dir}/../..
|
||||||
|
|
||||||
|
HABANA_VERSION := 1.20.0
|
||||||
|
PYTORCH_VERSION := 2.6.0
|
||||||
|
|
||||||
|
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
|
||||||
|
|
||||||
|
image:
|
||||||
|
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
|
||||||
|
|
||||||
|
run-local-dev-container:
|
||||||
|
docker run -it \
|
||||||
|
--runtime=habana \
|
||||||
|
--ipc=host \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--net=host \
|
||||||
|
-e HABANA_VISIBLE_DEVICES=all \
|
||||||
|
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
|
||||||
|
-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
|
||||||
|
-e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
|
||||||
|
-e LOG_LEVEL=debug \
|
||||||
|
-e PORT=8080 \
|
||||||
|
-v /home/ubuntu/.cache/huggingface:/data \
|
||||||
|
-v $(PWD):/text-generation-inference \
|
||||||
|
-w /text-generation-inference \
|
||||||
|
vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest
|
||||||
|
|
||||||
|
install-dependencies:
|
||||||
|
pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
|
||||||
|
pip install outlines~=0.0.34
|
||||||
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||||
|
|
||||||
|
install-server:
|
||||||
|
make -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3
|
||||||
|
|
||||||
|
install-router:
|
||||||
|
make -C ${root_dir} install-router
|
||||||
|
|
||||||
|
install-launcher:
|
||||||
|
make -C ${root_dir} install-launcher
|
||||||
|
|
||||||
|
# Use source to load the Rust toolchain into PATH
|
||||||
|
local-dev-install: install-dependencies
|
||||||
|
bash -c 'source "$$HOME/.cargo/env" && \
|
||||||
|
make install-server && \
|
||||||
|
make install-router && \
|
||||||
|
make install-launcher'
|
||||||
|
|
||||||
|
# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
|
||||||
|
run-integration-tests:
|
||||||
|
uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
|
||||||
|
DOCKER_VOLUME=${root_dir}/data \
|
||||||
|
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
|
||||||
|
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests
|
||||||
|
|
||||||
|
# Capture the expected outputs for the integration tests, offering an easy way to add more models to the test suite
|
||||||
|
capture-expected-outputs-for-integration-tests:
|
||||||
|
DOCKER_VOLUME=${root_dir}/data \
|
||||||
|
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
|
||||||
|
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py

backends/gaudi/README.md | 142 (new file)
@@ -0,0 +1,142 @@
# Text-generation-inference - Gaudi backend
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
This is the TGI backend for Intel Gaudi. It consists of the TGI server optimized for Gaudi hardware.
|
||||||
|
|
||||||
|
## Build your own image
|
||||||
|
|
||||||
|
The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:
|
||||||
|
|
||||||
|
Option 1: From the project root directory:
|
||||||
|
```bash
|
||||||
|
make -C backends/gaudi image
|
||||||
|
```
|
||||||
|
|
||||||
|
Option 2: From the Gaudi backend directory:
|
||||||
|
```bash
|
||||||
|
cd backends/gaudi
|
||||||
|
make image
|
||||||
|
```
|
||||||
|
|
||||||
|
You can now run the server with the following command:
|
||||||
|
|
||||||
|
Option 1: Sharded:
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
hf_token=$(cat ${HOME}/.cache/huggingface/token)
|
||||||
|
volume=${HOME}/.cache/huggingface
|
||||||
|
|
||||||
|
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||||
|
-p 8080:80 -v $volume:/data \
|
||||||
|
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
|
||||||
|
tgi-gaudi --model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048
|
||||||
|
```
|
||||||
|
|
||||||
|
Option 2: Non-sharded:
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
hf_token=$(cat ${HOME}/.cache/huggingface/token)
|
||||||
|
volume=${HOME}/.cache/huggingface
|
||||||
|
|
||||||
|
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||||
|
-p 8080:80 -v $volume:/data \
|
||||||
|
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
|
||||||
|
tgi-gaudi --model-id $model \
|
||||||
|
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
### Local Development
|
||||||
|
|
||||||
|
This is useful if you want to run the server locally for better debugging.
|
||||||
|
```bash
|
||||||
|
make -C backends/gaudi run-local-dev-container
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run the following command inside the container to install TGI for Gaudi:
|
||||||
|
```bash
|
||||||
|
make -C backends/gaudi local-dev-install
|
||||||
|
```
|
||||||
|
|
||||||
|
Add rust to path:
|
||||||
|
```bash
|
||||||
|
. "$HOME/.cargo/env"
|
||||||
|
```
|
||||||
|
|
||||||
|
Option 1: Run the server (sharded model):
|
||||||
|
```bash
|
||||||
|
LOG_LEVEL=debug text-generation-launcher \
|
||||||
|
--model-id meta-llama/Llama-3.1-8B-Instruct \
|
||||||
|
--sharded true \
|
||||||
|
--num-shard 8 \
|
||||||
|
--max-input-tokens 512 \
|
||||||
|
--max-total-tokens 1024 \
|
||||||
|
--max-batch-size 8 \
|
||||||
|
--max-batch-prefill-tokens 2048
|
||||||
|
```
|
||||||
|
|
||||||
|
Option 2: Run the server (non-sharded model):
|
||||||
|
```bash
|
||||||
|
LOG_LEVEL=debug text-generation-launcher \
|
||||||
|
--model-id meta-llama/Llama-3.1-8B-Instruct \
|
||||||
|
--max-input-tokens 512 \
|
||||||
|
--max-total-tokens 1024 \
|
||||||
|
--max-batch-size 4 \
|
||||||
|
--max-batch-prefill-tokens 2048
|
||||||
|
```
|
||||||
|
|
||||||
|
You can then test the server with the following curl command from another terminal (it can be run outside the container):
|
||||||
|
```bash
|
||||||
|
curl 127.0.0.1:8080/generate \
|
||||||
|
-X POST \
|
||||||
|
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
|
||||||
|
-H 'Content-Type: application/json'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Integration tests
|
||||||
|
|
||||||
|
To run the integration tests, you need to first build the image:
|
||||||
|
```bash
|
||||||
|
make -C backends/gaudi image
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run the following command to run the integration tests:
|
||||||
|
```bash
|
||||||
|
make -C backends/gaudi run-integration-tests
|
||||||
|
```
|
||||||
|
|
||||||
|
To capture the expected outputs for the integration tests, you can run the following command:
|
||||||
|
```bash
|
||||||
|
make -C backends/gaudi capture-expected-outputs-for-integration-tests
|
||||||
|
```
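
For reference, the capture script in this backend stores one entry per model in `expected_outputs.json`. Below is a minimal sketch of that structure with placeholder values (the real values are the captured generations):

```python
# Illustrative only: the shape of one entry written by capture_expected_outputs.py.
# The keys mirror the capture script in this backend; the values are made-up placeholders.
import json

entry = {
    "meta-llama/Llama-3.1-8B-Instruct": {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "input": "What is Deep Learning?",
        "greedy_output": "<captured greedy generation>",
        "batch_outputs": ["<captured batch generation>"] * 4,
        "args": ["--max-input-tokens", "512", "--max-total-tokens", "1024"],
    }
}

print(json.dumps(entry, indent=2))
```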
|
||||||
|
|
||||||
|
#### How the integration tests work

The integration tests work as follows:
|
||||||
|
|
||||||
|
1. Start a tgi server in a container, similar to the command:
|
||||||
|
```bash
|
||||||
|
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||||
|
-p 8080:80 -v $volume:/data \
|
||||||
|
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
|
||||||
|
tgi-gaudi --model-id $model \
|
||||||
|
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Do a /generate request to the server, similar to the command:
|
||||||
|
```bash
|
||||||
|
curl 127.0.0.1:8080/generate \
|
||||||
|
-X POST \
|
||||||
|
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
|
||||||
|
-H 'Content-Type: application/json'
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Check the output of the server against the expected output:
|
||||||
|
```python
|
||||||
|
assert curl_output == expected_output
|
||||||
|
```
|
||||||
|
|
||||||
|
This is then repeated for a set of models and configurations.
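
The comparison step does not need to be byte-exact; the test suite depends on the Levenshtein package, which allows a small edit-distance tolerance. Here is a minimal sketch of that idea; the threshold below is an assumption, not the exact tolerance used by the tests:

```python
# Sketch of a fuzzy output check, assuming a Levenshtein-based tolerance.
# The real assertions live in test_model.py; max_distance here is an assumed value.
from Levenshtein import distance as levenshtein_distance


def outputs_match(generated: str, expected: str, max_distance: int = 5) -> bool:
    """Accept the output if it is within a small edit distance of the expectation."""
    return levenshtein_distance(generated, expected) <= max_distance


# Example: a curly vs. straight apostrophe is tolerated.
assert outputs_match(
    " A Beginner’s Guide\nDeep learning is a subset of machine learning",
    " A Beginner's Guide\nDeep learning is a subset of machine learning",
)
```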

backends/gaudi/examples/docker_commands/docker_commands.md | 283 (new file)
@@ -0,0 +1,283 @@
# Examples of Docker Commands for Gaudi Backend
|
||||||
|
|
||||||
|
This page lists example `docker run` commands for some of the most popular models.
|
||||||
|
|
||||||
|
> **Note:** The parameters are chosen to maximize performance on Gaudi2 hardware; please adjust them for your own hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
|
||||||
|
|
||||||
|
## Default Precision (BF16)
|
||||||
|
|
||||||
|
### Llama3.1-8B on 1 card (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama3.1-70B 8 cards (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama2-7B on 1 Card (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-7b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama2-70B on 8 cards (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-70b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llava-v1.6-Mistral-7B on 1 card (BF16)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||||
|
-e BATCH_BUCKET_SIZE=1 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||||
|
--max-total-tokens 8192 --max-batch-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
## FP8 Precision
|
||||||
|
|
||||||
|
Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You first need to measure the statistics of the model before running it in FP8 precision.
|
||||||
|
|
||||||
|
## Llama3.1-8B on 1 Card (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llama3.1-70B on 8 cards (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llama2-7B on 1 Card (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-7b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||||
|
-e BATCH_BUCKET_SIZE=32 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llama2-70B on 8 Cards (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=meta-llama/Llama-2-70b-chat-hf
|
||||||
|
hf_token=YOUR_ACCESS_TOKEN
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e HF_TOKEN=$hf_token \
|
||||||
|
-e MAX_TOTAL_TOKENS=2048 \
|
||||||
|
-e BATCH_BUCKET_SIZE=256 \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||||
|
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||||
|
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||||
|
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llava-v1.6-Mistral-7B on 1 Card (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||||
|
-e BATCH_BUCKET_SIZE=1 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||||
|
--max-total-tokens 8192 --max-batch-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||||
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
|
docker run -p 8080:80 \
|
||||||
|
--runtime=habana \
|
||||||
|
--cap-add=sys_nice \
|
||||||
|
--ipc=host \
|
||||||
|
-v $volume:/data \
|
||||||
|
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||||
|
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||||
|
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||||
|
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||||
|
-e BATCH_BUCKET_SIZE=1 \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||||
|
--model-id $model \
|
||||||
|
--sharded true --num-shard 8 \
|
||||||
|
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||||
|
--max-total-tokens 8192 --max-batch-size 4
|
||||||
|
```

backends/gaudi/server/.gitignore | 164 (new file, vendored)
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
text_generation_server/__pycache__/
|
||||||
|
text_generation_server/pb/__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
transformers
|
||||||
|
safetensors
|
||||||
|
flash-attention/
|
||||||
|
flash-attention-v2/
|
||||||
|
vllm/
|
||||||
|
llm-awq/
|
||||||
|
eetq/
|
||||||
|
mamba/

backends/gaudi/server/Makefile | 38 (new file)
@@ -0,0 +1,38 @@
include Makefile-flash-att
|
||||||
|
include Makefile-flash-att-v2
|
||||||
|
include Makefile-vllm
|
||||||
|
include Makefile-awq
|
||||||
|
include Makefile-eetq
|
||||||
|
include Makefile-selective-scan
|
||||||
|
|
||||||
|
PROTO_PATH ?= ../proto/v3
|
||||||
|
|
||||||
|
unit-tests:
|
||||||
|
pytest -s -vv -m "not private" tests
|
||||||
|
|
||||||
|
gen-server:
|
||||||
|
# Compile protos
|
||||||
|
pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
|
||||||
|
mkdir text_generation_server/pb || true
|
||||||
|
python -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \
|
||||||
|
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto
|
||||||
|
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||||
|
touch text_generation_server/pb/__init__.py
|
||||||
|
|
||||||
|
install: gen-server
|
||||||
|
pip install pip --upgrade
|
||||||
|
pip install --no-deps -r requirements.txt
|
||||||
|
pip install -e "."
|
||||||
|
|
||||||
|
run-dev:
|
||||||
|
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||||
|
|
||||||
|
install-poetry:
|
||||||
|
curl -sSL https://install.python-poetry.org | python3 -
|
||||||
|
|
||||||
|
update-lock:
|
||||||
|
rm poetry.lock
|
||||||
|
poetry lock --no-update
|
||||||
|
|
||||||
|
export-requirements:
|
||||||
|
poetry export -o requirements.txt --without-hashes

backends/gaudi/server/Makefile-awq | 15 (new file)
@@ -0,0 +1,15 @@
# Fork that adds only the correct stream to this kernel in order
|
||||||
|
# to make cuda graphs work.
|
||||||
|
awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
|
||||||
|
|
||||||
|
awq:
|
||||||
|
rm -rf llm-awq
|
||||||
|
git clone https://github.com/huggingface/llm-awq
|
||||||
|
|
||||||
|
build-awq: awq
|
||||||
|
cd llm-awq/ && git fetch && git checkout $(awq_commit)
|
||||||
|
cd llm-awq/awq/kernels && python setup.py build
|
||||||
|
|
||||||
|
install-awq: build-awq
|
||||||
|
pip uninstall awq_inference_engine -y || true
|
||||||
|
cd llm-awq/awq/kernels && python setup.py install

backends/gaudi/server/Makefile-eetq | 13 (new file)
@@ -0,0 +1,13 @@
eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
|
||||||
|
|
||||||
|
eetq:
|
||||||
|
# Clone eetq
|
||||||
|
pip install packaging
|
||||||
|
git clone https://github.com/NetEase-FuXi/EETQ.git eetq
|
||||||
|
|
||||||
|
build-eetq: eetq
|
||||||
|
cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
|
||||||
|
cd eetq && python setup.py build
|
||||||
|
|
||||||
|
install-eetq: build-eetq
|
||||||
|
cd eetq && python setup.py install

@@ -1,10 +1,10 @@
-fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca
+fbgemm_commit := v0.8.0
 
 build-fbgemm:
-	chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \
-	git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
-	cp fbgemm_remove_unused.patch fbgemm && \
-	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \
+	@if [ ! -d "fbgemm" ]; then \
+		git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
+	fi
+	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
 	git submodule update --init --recursive && \
 	cd fbgemm_gpu && \
 	pip install -r requirements.txt && \

backends/gaudi/server/Makefile-flash-att | 12 (new file)
@@ -0,0 +1,12 @@
flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
|
||||||
|
|
||||||
|
build-flash-attention:
|
||||||
|
if [ ! -d 'flash-attention' ]; then \
|
||||||
|
pip install -U packaging ninja --no-cache-dir && \
|
||||||
|
git clone https://github.com/HazyResearch/flash-attention.git; \
|
||||||
|
fi
|
||||||
|
cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
|
||||||
|
MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build
|
||||||
|
|
||||||
|
install-flash-attention: build-flash-attention
|
||||||
|
cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install

backends/gaudi/server/Makefile-flash-att-v2 | 21 (new file)
@@ -0,0 +1,21 @@
flash_att_v2_commit_cuda := v2.6.1
|
||||||
|
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
|
||||||
|
|
||||||
|
build-flash-attention-v2-cuda:
|
||||||
|
pip install -U packaging wheel
|
||||||
|
pip install flash-attn==$(flash_att_v2_commit_cuda)
|
||||||
|
|
||||||
|
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
||||||
|
echo "Flash v2 installed"
|
||||||
|
|
||||||
|
build-flash-attention-v2-rocm:
|
||||||
|
if [ ! -d 'flash-attention-v2' ]; then \
|
||||||
|
pip install -U packaging ninja --no-cache-dir && \
|
||||||
|
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
|
||||||
|
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
|
||||||
|
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
|
||||||
|
cd flash-attention-v2 && \
|
||||||
|
GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install

backends/gaudi/server/Makefile-selective-scan | 28 (new file)
@@ -0,0 +1,28 @@
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
|
||||||
|
|
||||||
|
causal-conv1d:
|
||||||
|
rm -rf causal-conv1d
|
||||||
|
git clone https://github.com/Dao-AILab/causal-conv1d.git
|
||||||
|
|
||||||
|
build-causal-conv1d: causal-conv1d
|
||||||
|
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
|
||||||
|
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
|
||||||
|
|
||||||
|
install-causal-conv1d: build-causal-conv1d
|
||||||
|
pip uninstall causal-conv1d -y || true
|
||||||
|
cd causal-conv1d/ && pip install .
|
||||||
|
|
||||||
|
# selective-scan depends on causal-conv1d
|
||||||
|
selective-scan:
|
||||||
|
rm -rf mamba
|
||||||
|
git clone https://github.com/state-spaces/mamba.git mamba
|
||||||
|
|
||||||
|
build-selective-scan: selective-scan
|
||||||
|
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
|
||||||
|
cd mamba && python setup.py build
|
||||||
|
|
||||||
|
install-selective-scan: install-causal-conv1d build-selective-scan
|
||||||
|
pip uninstall selective-scan-cuda -y || true
|
||||||
|
cd mamba && pip install .
|
||||||
|
|
||||||
|
build-all: build-causal-conv1d build-selective-scan

backends/gaudi/server/Makefile-vllm | 23 (new file)
@@ -0,0 +1,23 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
|
||||||
|
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
|
||||||
|
build-vllm-cuda:
|
||||||
|
if [ ! -d 'vllm' ]; then \
|
||||||
|
pip install -U ninja packaging --no-cache-dir && \
|
||||||
|
git clone https://github.com/Narsil/vllm.git vllm; \
|
||||||
|
fi
|
||||||
|
cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
|
||||||
|
|
||||||
|
install-vllm-cuda: build-vllm-cuda
|
||||||
|
cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
|
||||||
|
|
||||||
|
build-vllm-rocm:
|
||||||
|
if [ ! -d 'vllm' ]; then \
|
||||||
|
pip install -U ninja packaging --no-cache-dir && \
|
||||||
|
git clone https://github.com/mht-sharma/vllm.git vllm; \
|
||||||
|
fi
|
||||||
|
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||||
|
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||||
|
|
||||||
|
install-vllm-rocm: build-vllm-rocm
|
||||||
|
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||||
|
PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .

backends/gaudi/server/README.md | 15 (new file)
@@ -0,0 +1,15 @@
# Text Generation Inference Python gRPC Server

A Python gRPC server for Text Generation Inference

## Install

```shell
make install
```

## Run

```shell
make run-dev
```

backends/gaudi/server/dill-0.3.7-patch.sh | 91 (new file)
@@ -0,0 +1,91 @@
#!/bin/bash
|
||||||
|
git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
|
||||||
|
pushd dill
|
||||||
|
cat <<EOF > dill-0.3.7.patch
|
||||||
|
diff --git a/dill/_dill.py b/dill/_dill.py
|
||||||
|
index d0cf543..f6eb662 100644
|
||||||
|
--- a/dill/_dill.py
|
||||||
|
+++ b/dill/_dill.py
|
||||||
|
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
|
||||||
|
XRangeType = range
|
||||||
|
from types import MappingProxyType as DictProxyType, new_class
|
||||||
|
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
|
||||||
|
-import __main__ as _main_module
|
||||||
|
+class _LazyMainModule(object):
|
||||||
|
+ _module = None
|
||||||
|
+ @property
|
||||||
|
+ def module(self):
|
||||||
|
+ if self._module is None:
|
||||||
|
+ import __main__ as _m_module
|
||||||
|
+ self._module = _m_module
|
||||||
|
+ return self._module
|
||||||
|
+_main_module = _LazyMainModule()
|
||||||
|
import marshal
|
||||||
|
import gc
|
||||||
|
# import zlib
|
||||||
|
@@ -353,7 +361,7 @@ class Pickler(StockPickler):
|
||||||
|
_fmode = kwds.pop('fmode', None)
|
||||||
|
_recurse = kwds.pop('recurse', None)
|
||||||
|
StockPickler.__init__(self, file, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._diff_cache = {}
|
||||||
|
self._byref = settings['byref'] if _byref is None else _byref
|
||||||
|
self._strictio = False #_strictio
|
||||||
|
@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
|
||||||
|
settings = Pickler.settings
|
||||||
|
_ignore = kwds.pop('ignore', None)
|
||||||
|
StockUnpickler.__init__(self, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._ignore = settings['ignore'] if _ignore is None else _ignore
|
||||||
|
|
||||||
|
def load(self): #NOTE: if settings change, need to update attributes
|
||||||
|
obj = StockUnpickler.load(self)
|
||||||
|
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
|
||||||
|
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
|
||||||
|
if not self._ignore:
|
||||||
|
# point obj class to main
|
||||||
|
try: obj.__class__ = getattr(self._main, type(obj).__name__)
|
||||||
|
@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
|
||||||
|
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
|
||||||
|
logger.trace(pickler, "# D1")
|
||||||
|
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
|
||||||
|
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
|
||||||
|
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
|
||||||
|
logger.trace(pickler, "# D3")
|
||||||
|
- elif '__name__' in obj and obj != _main_module.__dict__ \\
|
||||||
|
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
|
||||||
|
and type(obj['__name__']) is str \\
|
||||||
|
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
|
||||||
|
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
|
||||||
|
diff --git a/dill/session.py b/dill/session.py
|
||||||
|
index 74234ab..1be8d89 100644
|
||||||
|
--- a/dill/session.py
|
||||||
|
+++ b/dill/session.py
|
||||||
|
@@ -233,7 +233,7 @@ def dump_module(
|
||||||
|
protocol = settings['protocol']
|
||||||
|
main = module
|
||||||
|
if main is None:
|
||||||
|
- main = _main_module
|
||||||
|
+ main = _main_module.module
|
||||||
|
elif isinstance(main, str):
|
||||||
|
main = _import_module(main)
|
||||||
|
if not isinstance(main, ModuleType):
|
||||||
|
@@ -501,7 +501,7 @@ def load_module(
|
||||||
|
pass
|
||||||
|
assert loaded is main
|
||||||
|
_restore_modules(unpickler, main)
|
||||||
|
- if main is _main_module or main is module:
|
||||||
|
+ if main is _main_module.module or main is module:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return main
|
||||||
|
|
||||||
|
EOF
|
||||||
|
git apply dill-0.3.7.patch
|
||||||
|
python -m pip install .
|
||||||
|
popd
|
||||||
|
rm -fr dill

backends/gaudi/server/dill-0.3.8-patch.sh | 91 (new file)
@@ -0,0 +1,91 @@
#!/bin/bash
|
||||||
|
git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
|
||||||
|
pushd dill
|
||||||
|
cat <<EOF > dill-0.3.8.patch
|
||||||
|
diff --git a/dill/_dill.py b/dill/_dill.py
|
||||||
|
index d42432f..1d251e6 100644
|
||||||
|
--- a/dill/_dill.py
|
||||||
|
+++ b/dill/_dill.py
|
||||||
|
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
|
||||||
|
XRangeType = range
|
||||||
|
from types import MappingProxyType as DictProxyType, new_class
|
||||||
|
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
|
||||||
|
-import __main__ as _main_module
|
||||||
|
+class _LazyMainModule(object):
|
||||||
|
+ _module = None
|
||||||
|
+ @property
|
||||||
|
+ def module(self):
|
||||||
|
+ if self._module is None:
|
||||||
|
+ import __main__ as _m_module
|
||||||
|
+ self._module = _m_module
|
||||||
|
+ return self._module
|
||||||
|
+_main_module = _LazyMainModule()
|
||||||
|
import marshal
|
||||||
|
import gc
|
||||||
|
# import zlib
|
||||||
|
@@ -355,7 +363,7 @@ class Pickler(StockPickler):
|
||||||
|
_fmode = kwds.pop('fmode', None)
|
||||||
|
_recurse = kwds.pop('recurse', None)
|
||||||
|
StockPickler.__init__(self, file, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._diff_cache = {}
|
||||||
|
self._byref = settings['byref'] if _byref is None else _byref
|
||||||
|
self._strictio = False #_strictio
|
||||||
|
@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
|
||||||
|
settings = Pickler.settings
|
||||||
|
_ignore = kwds.pop('ignore', None)
|
||||||
|
StockUnpickler.__init__(self, *args, **kwds)
|
||||||
|
- self._main = _main_module
|
||||||
|
+ self._main = _main_module.module
|
||||||
|
self._ignore = settings['ignore'] if _ignore is None else _ignore
|
||||||
|
|
||||||
|
def load(self): #NOTE: if settings change, need to update attributes
|
||||||
|
obj = StockUnpickler.load(self)
|
||||||
|
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
|
||||||
|
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
|
||||||
|
if not self._ignore:
|
||||||
|
# point obj class to main
|
||||||
|
try: obj.__class__ = getattr(self._main, type(obj).__name__)
|
||||||
|
@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
|
||||||
|
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
|
||||||
|
logger.trace(pickler, "# D1")
|
||||||
|
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
|
||||||
|
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
|
||||||
|
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
|
||||||
|
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
|
||||||
|
logger.trace(pickler, "# D3")
|
||||||
|
- elif '__name__' in obj and obj != _main_module.__dict__ \\
|
||||||
|
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
|
||||||
|
and type(obj['__name__']) is str \\
|
||||||
|
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
|
||||||
|
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
|
||||||
|
diff --git a/dill/session.py b/dill/session.py
|
||||||
|
index e91068a..a921b43 100644
|
||||||
|
--- a/dill/session.py
|
||||||
|
+++ b/dill/session.py
|
||||||
|
@@ -233,7 +233,7 @@ def dump_module(
|
||||||
|
protocol = settings['protocol']
|
||||||
|
main = module
|
||||||
|
if main is None:
|
||||||
|
- main = _main_module
|
||||||
|
+ main = _main_module.module
|
||||||
|
elif isinstance(main, str):
|
||||||
|
main = _import_module(main)
|
||||||
|
if not isinstance(main, ModuleType):
|
||||||
|
@@ -501,7 +501,7 @@ def load_module(
|
||||||
|
pass
|
||||||
|
assert loaded is main
|
||||||
|
_restore_modules(unpickler, main)
|
||||||
|
- if main is _main_module or main is module:
|
||||||
|
+ if main is _main_module.module or main is module:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return main
|
||||||
|
|
||||||
|
EOF
|
||||||
|
git apply dill-0.3.8.patch
|
||||||
|
python -m pip install .
|
||||||
|
popd
|
||||||
|
rm -fr dill

backends/gaudi/server/integration-tests/capture_expected_outputs.py | 85 (new file)
@@ -0,0 +1,85 @@
import json
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from test_model import TEST_CONFIGS
|
||||||
|
|
||||||
|
UNKNOWN_CONFIGS = {
|
||||||
|
name: config
|
||||||
|
for name, config in TEST_CONFIGS.items()
|
||||||
|
if config["expected_greedy_output"] == "unknown"
|
||||||
|
or config["expected_batch_output"] == "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
|
||||||
|
def test_config(request) -> Dict[str, Any]:
|
||||||
|
"""Fixture that provides model configurations for testing."""
|
||||||
|
test_config = UNKNOWN_CONFIGS[request.param]
|
||||||
|
test_config["test_name"] = request.param
|
||||||
|
return test_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def test_name(test_config):
|
||||||
|
yield test_config["test_name"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def tgi_service(launcher, test_config, test_name) -> Generator:
|
||||||
|
"""Fixture that provides a TGI service for testing."""
|
||||||
|
with launcher(test_config["model_id"], test_name) as service:
|
||||||
|
yield service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_capture_expected_outputs(tgi_service, test_config, test_name):
|
||||||
|
"""Test that captures expected outputs for models with unknown outputs."""
|
||||||
|
print(f"Testing {test_name} with {test_config['model_id']}")
|
||||||
|
|
||||||
|
# Wait for service to be ready
|
||||||
|
await tgi_service.health(1000)
|
||||||
|
client = tgi_service.client
|
||||||
|
|
||||||
|
# Test single request (greedy)
|
||||||
|
print("Testing single request...")
|
||||||
|
response = await client.generate(
|
||||||
|
test_config["input"],
|
||||||
|
max_new_tokens=32,
|
||||||
|
)
|
||||||
|
greedy_output = response.generated_text
|
||||||
|
|
||||||
|
# Test multiple requests (batch)
|
||||||
|
print("Testing batch requests...")
|
||||||
|
responses = []
|
||||||
|
for _ in range(4):
|
||||||
|
response = await client.generate(
|
||||||
|
test_config["input"],
|
||||||
|
max_new_tokens=32,
|
||||||
|
)
|
||||||
|
responses.append(response.generated_text)
|
||||||
|
|
||||||
|
# Store results in a JSON file
|
||||||
|
output_file = "server/integration-tests/expected_outputs.json"
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# Try to load existing results if file exists
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
with open(output_file, "r") as f:
|
||||||
|
results = json.load(f)
|
||||||
|
|
||||||
|
# Update results for this model
|
||||||
|
results[test_name] = {
|
||||||
|
"model_id": test_config["model_id"],
|
||||||
|
"input": test_config["input"],
|
||||||
|
"greedy_output": greedy_output,
|
||||||
|
"batch_outputs": responses,
|
||||||
|
"args": test_config["args"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save updated results
|
||||||
|
with open(output_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
print(f"\nResults for {test_name} saved to {output_file}")

backends/gaudi/server/integration-tests/conftest.py | 292 (new file)
@@ -0,0 +1,292 @@
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import List
|
||||||
|
import socket
|
||||||
|
|
||||||
|
import docker
|
||||||
|
import pytest
|
||||||
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
|
from docker.errors import NotFound
|
||||||
|
from loguru import logger
|
||||||
|
from test_model import TEST_CONFIGS
|
||||||
|
from text_generation import AsyncClient
|
||||||
|
from text_generation.types import Response
|
||||||
|
|
||||||
|
# Use the latest image from the local docker build
|
||||||
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
|
||||||
|
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
|
||||||
|
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
HF_TOKEN is not None
|
||||||
|
), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
|
||||||
|
|
||||||
|
if DOCKER_VOLUME is None:
|
||||||
|
logger.warning(
|
||||||
|
"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
|
||||||
|
|
||||||
|
BASE_ENV = {
|
||||||
|
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
||||||
|
"LOG_LEVEL": LOG_LEVEL,
|
||||||
|
"HF_TOKEN": os.getenv("HF_TOKEN", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
HABANA_RUN_ARGS = {
|
||||||
|
"runtime": "habana",
|
||||||
|
"ipc_mode": "host",
|
||||||
|
"cap_add": ["sys_nice"],
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||||
|
level="INFO",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_container_logs(container, test_name):
|
||||||
|
"""Stream container logs in a separate thread."""
|
||||||
|
try:
|
||||||
|
for log in container.logs(stream=True, follow=True):
|
||||||
|
print(
|
||||||
|
f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
|
||||||
|
end="",
|
||||||
|
file=sys.stderr,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error streaming container logs: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
class LauncherHandle:
|
||||||
|
def __init__(self, port: int):
|
||||||
|
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
|
||||||
|
|
||||||
|
def _inner_health(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def health(self, timeout: int = 60):
|
||||||
|
assert timeout > 0
|
||||||
|
start_time = time.time()
|
||||||
|
logger.info(f"Starting health check with timeout of {timeout}s")
|
||||||
|
|
||||||
|
for attempt in range(timeout):
|
||||||
|
if not self._inner_health():
|
||||||
|
logger.error("Launcher crashed during health check")
|
||||||
|
raise RuntimeError("Launcher crashed")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.client.generate("test")
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.info(f"Health check passed after {elapsed:.1f}s")
|
||||||
|
return
|
||||||
|
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
|
||||||
|
if attempt == timeout - 1:
|
||||||
|
logger.error(f"Health check failed after {timeout}s: {str(e)}")
|
||||||
|
raise RuntimeError(f"Health check failed: {str(e)}")
|
||||||
|
if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
|
||||||
|
logger.debug(
|
||||||
|
f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
|
||||||
|
)
|
||||||
|
time.sleep(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during health check: {str(e)}")
|
||||||
|
# Get full traceback for debugging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"Full traceback:\n{traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerLauncherHandle(LauncherHandle):
|
||||||
|
def __init__(self, docker_client, container_name, port: int):
|
||||||
|
super(ContainerLauncherHandle, self).__init__(port)
|
||||||
|
self.docker_client = docker_client
|
||||||
|
self.container_name = container_name
|
||||||
|
|
||||||
|
def _inner_health(self) -> bool:
|
||||||
|
try:
|
||||||
|
container = self.docker_client.containers.get(self.container_name)
|
||||||
|
status = container.status
|
||||||
|
if status not in ["running", "created"]:
|
||||||
|
logger.warning(f"Container status is {status}")
|
||||||
|
# Get container logs for debugging
|
||||||
|
logs = container.logs().decode("utf-8")
|
||||||
|
logger.debug(f"Container logs:\n{logs}")
|
||||||
|
return status in ["running", "created"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking container health: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessLauncherHandle(LauncherHandle):
|
||||||
|
def __init__(self, process, port: int):
|
||||||
|
super(ProcessLauncherHandle, self).__init__(port)
|
||||||
|
self.process = process
|
||||||
|
|
||||||
|
def _inner_health(self) -> bool:
|
||||||
|
return self.process.poll() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def data_volume():
|
||||||
|
tmpdir = TemporaryDirectory()
|
||||||
|
yield tmpdir.name
|
||||||
|
try:
|
||||||
|
# Cleanup the temporary directory using sudo as it contains root files created by the container
|
||||||
|
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.error(f"Error cleaning up temporary directory: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def launcher(data_volume):
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def docker_launcher(
|
||||||
|
model_id: str,
|
||||||
|
test_name: str,
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Starting docker launcher for model {model_id} and test {test_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get a random available port
|
||||||
|
def get_free_port():
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", 0))
|
||||||
|
s.listen(1)
|
||||||
|
port = s.getsockname()[1]
|
||||||
|
return port
|
||||||
|
|
||||||
|
port = get_free_port()
|
||||||
|
logger.debug(f"Using port {port}")
|
||||||
|
|
||||||
|
client = docker.from_env()
|
||||||
|
|
||||||
|
container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
container = client.containers.get(container_name)
|
||||||
|
logger.info(
|
||||||
|
f"Stopping existing container {container_name} for test {test_name}"
|
||||||
|
)
|
||||||
|
container.stop()
|
||||||
|
container.wait()
|
||||||
|
except NotFound:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling existing container: {str(e)}")
|
||||||
|
|
||||||
|
model_name = next(
|
||||||
|
name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
|
||||||
|
)
|
||||||
|
|
||||||
|
tgi_args = TEST_CONFIGS[model_name]["args"].copy()
|
||||||
|
|
||||||
|
env = BASE_ENV.copy()
|
||||||
|
|
||||||
|
# Add model_id to env
|
||||||
|
env["MODEL_ID"] = model_id
|
||||||
|
|
||||||
|
# Add env config that is defined in the fixture parameter
|
||||||
|
if "env_config" in TEST_CONFIGS[model_name]:
|
||||||
|
env.update(TEST_CONFIGS[model_name]["env_config"].copy())
|
||||||
|
|
||||||
|
volumes = [f"{DOCKER_VOLUME}:/data"]
|
||||||
|
logger.debug(f"Using volume {volumes}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Creating container with name {container_name}")
|
||||||
|
|
||||||
|
# Log equivalent docker run command for debugging, this is not actually executed
|
||||||
|
container = client.containers.run(
|
||||||
|
DOCKER_IMAGE,
|
||||||
|
command=tgi_args,
|
||||||
|
name=container_name,
|
||||||
|
environment=env,
|
||||||
|
detach=True,
|
||||||
|
volumes=volumes,
|
||||||
|
ports={"80/tcp": port},
|
||||||
|
**HABANA_RUN_ARGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Container {container_name} started successfully")
|
||||||
|
|
||||||
|
# Start log streaming in a background thread
|
||||||
|
log_thread = threading.Thread(
|
||||||
|
target=stream_container_logs,
|
||||||
|
args=(container, test_name),
|
||||||
|
daemon=True, # This ensures the thread will be killed when the main program exits
|
||||||
|
)
|
||||||
|
log_thread.start()
|
||||||
|
|
||||||
|
# Add a small delay to allow container to initialize
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Check container status after creation
|
||||||
|
status = container.status
|
||||||
|
logger.debug(f"Initial container status: {status}")
|
||||||
|
if status not in ["running", "created"]:
|
||||||
|
logs = container.logs().decode("utf-8")
|
||||||
|
logger.error(f"Container failed to start properly. Logs:\n{logs}")
|
||||||
|
|
||||||
|
yield ContainerLauncherHandle(client, container.name, port)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error starting container: {str(e)}")
|
||||||
|
# Get full traceback for debugging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"Full traceback:\n{traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
container = client.containers.get(container_name)
|
||||||
|
logger.info(f"Stopping container {container_name}")
|
||||||
|
container.stop()
|
||||||
|
container.wait()
|
||||||
|
|
||||||
|
container_output = container.logs().decode("utf-8")
|
||||||
|
print(container_output, file=sys.stderr)
|
||||||
|
|
||||||
|
container.remove()
|
||||||
|
logger.info(f"Container {container_name} removed successfully")
|
||||||
|
except NotFound:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning up container: {str(e)}")
|
||||||
|
|
||||||
|
return docker_launcher
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def generate_load():
|
||||||
|
async def generate_load_inner(
|
||||||
|
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
||||||
|
) -> List[Response]:
|
||||||
|
try:
|
||||||
|
futures = [
|
||||||
|
client.generate(
|
||||||
|
prompt,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
decoder_input_details=True,
|
||||||
|
)
|
||||||
|
for _ in range(n)
|
||||||
|
]
|
||||||
|
return await asyncio.gather(*futures)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating load: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return generate_load_inner

backends/gaudi/server/integration-tests/pytest.ini | 2 (new file)
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto

backends/gaudi/server/integration-tests/requirements.txt | 7 (new file)
@@ -0,0 +1,7 @@
pytest >= 8.3.5
|
||||||
|
pytest-asyncio >= 0.26.0
|
||||||
|
docker >= 7.1.0
|
||||||
|
Levenshtein >= 0.27.1
|
||||||
|
loguru >= 0.7.3
|
||||||
|
aiohttp >= 3.11.14
|
||||||
|
text-generation
|
276
backends/gaudi/server/integration-tests/test_model.py
Normal file
276
backends/gaudi/server/integration-tests/test_model.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from text_generation import AsyncClient
|
||||||
|
import pytest
|
||||||
|
from Levenshtein import distance as levenshtein_distance
|
||||||
|
|
||||||
|
# The "args" config is not optimized for speed but only check that the inference is working for the different models architectures
|
||||||
|
TEST_CONFIGS = {
|
||||||
|
"meta-llama/Llama-3.1-8B-Instruct-shared": {
|
||||||
|
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
|
||||||
|
"expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
|
||||||
|
"args": [
|
||||||
|
"--sharded",
|
||||||
|
"true",
|
||||||
|
"--num-shard",
|
||||||
|
"8",
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"8",
|
||||||
|
"--max-batch-prefill-tokens",
|
||||||
|
"2048",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"meta-llama/Llama-3.1-8B-Instruct": {
|
||||||
|
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
|
||||||
|
"expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
|
||||||
|
"env_config": {},
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
"--max-batch-prefill-tokens",
|
||||||
|
"2048",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"meta-llama/Llama-2-7b-chat-hf": {
|
||||||
|
"model_id": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
|
||||||
|
"expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
"--max-batch-prefill-tokens",
|
||||||
|
"2048",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"mistralai/Mistral-7B-Instruct-v0.3": {
|
||||||
|
"model_id": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
|
||||||
|
"expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
"--max-batch-prefill-tokens",
|
||||||
|
"2048",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"bigcode/starcoder2-3b": {
|
||||||
|
"model_id": "bigcode/starcoder2-3b",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
|
||||||
|
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
"--max-batch-prefill-tokens",
|
||||||
|
"2048",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"google/gemma-7b-it": {
|
||||||
|
"model_id": "google/gemma-7b-it",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
|
||||||
|
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
"--max-batch-prefill-tokens",
|
||||||
|
"2048",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"Qwen/Qwen2-0.5B-Instruct": {
|
||||||
|
"model_id": "Qwen/Qwen2-0.5B-Instruct",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
|
||||||
|
"expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
"--max-batch-prefill-tokens",
|
||||||
|
"2048",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"tiiuae/falcon-7b-instruct": {
|
||||||
|
"model_id": "tiiuae/falcon-7b-instruct",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
|
||||||
|
"expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"microsoft/phi-1_5": {
|
||||||
|
"model_id": "microsoft/phi-1_5",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
|
||||||
|
"expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"openai-community/gpt2": {
|
||||||
|
"model_id": "openai-community/gpt2",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
|
||||||
|
"expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"facebook/opt-125m": {
|
||||||
|
"model_id": "facebook/opt-125m",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
|
||||||
|
"expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"EleutherAI/gpt-j-6b": {
|
||||||
|
"model_id": "EleutherAI/gpt-j-6b",
|
||||||
|
"input": "What is Deep Learning?",
|
||||||
|
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
|
||||||
|
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
|
||||||
|
"args": [
|
||||||
|
"--max-input-tokens",
|
||||||
|
"512",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"1024",
|
||||||
|
"--max-batch-size",
|
||||||
|
"4",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Testing {len(TEST_CONFIGS)} models")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
|
||||||
|
def test_config(request) -> Dict[str, Any]:
|
||||||
|
"""Fixture that provides model configurations for testing."""
|
||||||
|
test_config = TEST_CONFIGS[request.param]
|
||||||
|
test_config["test_name"] = request.param
|
||||||
|
return test_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def model_id(test_config):
|
||||||
|
yield test_config["model_id"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def test_name(test_config):
|
||||||
|
yield test_config["test_name"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def expected_outputs(test_config):
|
||||||
|
return {
|
||||||
|
"greedy": test_config["expected_greedy_output"],
|
||||||
|
# "sampling": model_config["expected_sampling_output"],
|
||||||
|
"batch": test_config["expected_batch_output"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def input(test_config):
|
||||||
|
return test_config["input"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def tgi_service(launcher, model_id, test_name):
|
||||||
|
with launcher(model_id, test_name) as tgi_service:
|
||||||
|
yield tgi_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def tgi_client(tgi_service) -> AsyncClient:
|
||||||
|
await tgi_service.health(1000)
|
||||||
|
return tgi_service.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_single_request(
|
||||||
|
tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
|
||||||
|
):
|
||||||
|
# Bounded greedy decoding without input
|
||||||
|
response = await tgi_client.generate(
|
||||||
|
input,
|
||||||
|
max_new_tokens=32,
|
||||||
|
)
|
||||||
|
assert response.details.generated_tokens == 32
|
||||||
|
assert response.generated_text == expected_outputs["greedy"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_multiple_requests(
|
||||||
|
tgi_client, generate_load, expected_outputs, input
|
||||||
|
):
|
||||||
|
num_requests = 4
|
||||||
|
responses = await generate_load(
|
||||||
|
tgi_client,
|
||||||
|
input,
|
||||||
|
max_new_tokens=32,
|
||||||
|
n=num_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
expected = expected_outputs["batch"]
|
||||||
|
for r in responses:
|
||||||
|
assert r.details.generated_tokens == 32
|
||||||
|
# Compute the similarity with the expectation using the levenshtein distance
|
||||||
|
# We should not have more than two substitutions or additions
|
||||||
|
assert levenshtein_distance(r.generated_text, expected) < 3
|
3014
backends/gaudi/server/poetry.lock
generated
Normal file
3014
backends/gaudi/server/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
45
backends/gaudi/server/pyproject.toml
Normal file
45
backends/gaudi/server/pyproject.toml
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "text-generation-server"
|
||||||
|
version = "2.0.4"
|
||||||
|
description = "Text Generation Inference Python gRPC Server"
|
||||||
|
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||||
|
|
||||||
|
[tool.poetry.scripts]
|
||||||
|
text-generation-server = 'text_generation_server.cli:app'
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.9,<3.13"
|
||||||
|
protobuf = "^5.0"
|
||||||
|
grpcio = "^1.71.1"
|
||||||
|
grpcio-status = "*"
|
||||||
|
grpcio-reflection = "*"
|
||||||
|
grpc-interceptor = "^0.15.0"
|
||||||
|
typer = "^0.15.0"
|
||||||
|
loguru = "^0.7.3"
|
||||||
|
opentelemetry-api = "^1.32.0"
|
||||||
|
opentelemetry-exporter-otlp = "^1.32.0"
|
||||||
|
opentelemetry-instrumentation-grpc = "^0.53b0"
|
||||||
|
hf-transfer = "^0.1.9"
|
||||||
|
sentencepiece = "^0.2.0"
|
||||||
|
peft = "^0.15"
|
||||||
|
optimum-habana = "1.17"
|
||||||
|
transformers = "^4.49"
|
||||||
|
numpy = "^1.26"
|
||||||
|
accelerate = "^0.33"
|
||||||
|
outlines= { version = "^0.0.36", optional = true }
|
||||||
|
prometheus-client = "^0.21.1"
|
||||||
|
py-cpuinfo = "^9.0.0"
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
grpcio-tools = "*"
|
||||||
|
pytest = "^8.3.5"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.poetry.requires-plugins]
|
||||||
|
poetry-plugin-export = ">=1.8"
|
101
backends/gaudi/server/requirements.txt
Normal file
101
backends/gaudi/server/requirements.txt
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
|
deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
|
||||||
|
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
|
wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
|
@ -0,0 +1,13 @@
|
|||||||
|
# Origin: https://github.com/predibase/lorax
|
||||||
|
# Path: lorax/server/lorax_server/adapters/__init__.py
|
||||||
|
# License: Apache License Version 2.0, January 2004
|
||||||
|
|
||||||
|
from text_generation_server.adapters.weights import (
|
||||||
|
AdapterBatchData,
|
||||||
|
AdapterBatchMetadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AdapterBatchData",
|
||||||
|
"AdapterBatchMetadata",
|
||||||
|
]
|
@ -0,0 +1,30 @@
|
|||||||
|
# Origin: https://github.com/predibase/lorax
|
||||||
|
# Path: lorax/server/lorax_server/adapters/config.py
|
||||||
|
# License: Apache License Version 2.0, January 2004
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from text_generation_server.adapters.weights import AdapterWeights
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModuleMap:
|
||||||
|
module_name: str
|
||||||
|
module_weights: Dict[str, Tuple[torch.Tensor, str]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdapterConfig(ABC):
|
||||||
|
base_model_name_or_path: str
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def map_weights_for_model(
|
||||||
|
self,
|
||||||
|
adapter_weights: Dict[int, AdapterWeights],
|
||||||
|
weight_names: Tuple[str],
|
||||||
|
) -> Tuple[ModuleMap, Set[str]]:
|
||||||
|
pass
|
471
backends/gaudi/server/text_generation_server/adapters/lora.py
Normal file
471
backends/gaudi/server/text_generation_server/adapters/lora.py
Normal file
@ -0,0 +1,471 @@
|
|||||||
|
# Origin: https://github.com/predibase/lorax
|
||||||
|
# Path: lorax/server/lorax_server/adapters/lora.py
|
||||||
|
# License: Apache License Version 2.0, January 2004
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import LoraConfig as _LoraConfig
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||||
|
|
||||||
|
from text_generation_server.adapters.weights import (
|
||||||
|
AdapterBatchMetadata,
|
||||||
|
AdapterWeights,
|
||||||
|
BatchAdapterWeights,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.sgmv import (
|
||||||
|
BGMV_MAX_RANK,
|
||||||
|
MAX_RANK_CUSTOM,
|
||||||
|
get_tmp_tensors,
|
||||||
|
orient_for_rank,
|
||||||
|
pad_rank,
|
||||||
|
use_cutlass_shrink,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||||
|
block_size = size // world_size
|
||||||
|
start = offset + rank * block_size
|
||||||
|
stop = offset + (rank + 1) * block_size
|
||||||
|
return start, stop
|
||||||
|
|
||||||
|
|
||||||
|
def shard_on_dim(
|
||||||
|
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
|
||||||
|
):
|
||||||
|
world_size = process_group.size()
|
||||||
|
rank = process_group.rank()
|
||||||
|
|
||||||
|
size = t.shape[dim]
|
||||||
|
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
|
||||||
|
|
||||||
|
if dim == 0:
|
||||||
|
tensor = t[start:stop]
|
||||||
|
elif dim == 1:
|
||||||
|
tensor = t[:, start:stop]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Let's make that generic when needed")
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def shard_lora_weights(
|
||||||
|
weights_a: List[torch.Tensor],
|
||||||
|
weights_b: List[torch.Tensor],
|
||||||
|
split_dim: int,
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||||
|
# [hidden_size, r]
|
||||||
|
weights_a = [
|
||||||
|
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
|
||||||
|
]
|
||||||
|
|
||||||
|
# [r, hidden_size]
|
||||||
|
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
|
||||||
|
|
||||||
|
return weights_a, weights_b
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoraConfig(AdapterConfig):
|
||||||
|
r: int
|
||||||
|
target_modules: Optional[Union[List[str], str]]
|
||||||
|
fan_in_fan_out: bool
|
||||||
|
lora_alpha: int
|
||||||
|
use_rslora: bool
|
||||||
|
|
||||||
|
def map_weights_for_model(
|
||||||
|
self,
|
||||||
|
adapter_weights: Dict[int, AdapterWeights],
|
||||||
|
weight_names: Tuple[str],
|
||||||
|
) -> Tuple[ModuleMap, Set[str]]:
|
||||||
|
adapter_weight_names = set()
|
||||||
|
module_map = {}
|
||||||
|
for weight_name in weight_names:
|
||||||
|
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
|
||||||
|
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
|
||||||
|
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module_map[weight_name] = {
|
||||||
|
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
|
||||||
|
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
|
||||||
|
}
|
||||||
|
adapter_weight_names.add(lora_a_name)
|
||||||
|
adapter_weight_names.add(lora_b_name)
|
||||||
|
return module_map, adapter_weight_names
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
|
||||||
|
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
|
||||||
|
return cls(
|
||||||
|
base_model_name_or_path=hf_config.base_model_name_or_path,
|
||||||
|
r=hf_config.r,
|
||||||
|
target_modules=hf_config.target_modules,
|
||||||
|
fan_in_fan_out=hf_config.fan_in_fan_out,
|
||||||
|
lora_alpha=hf_config.lora_alpha,
|
||||||
|
use_rslora=(
|
||||||
|
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraWeights(AdapterWeights):
|
||||||
|
"""LoRA weights for a single adapter merged across all layers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weights_a: List[torch.Tensor],
|
||||||
|
weights_b: List[torch.Tensor],
|
||||||
|
adapter_config: LoraConfig,
|
||||||
|
):
|
||||||
|
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
|
||||||
|
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
|
||||||
|
|
||||||
|
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
|
||||||
|
self._is_transposed = False
|
||||||
|
|
||||||
|
# [num_layers, hidden_size, r]
|
||||||
|
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
|
||||||
|
self._weights_a = torch.stack(weights_a)
|
||||||
|
|
||||||
|
# [num_layers, r, hidden_size]
|
||||||
|
self._weights_b = torch.stack(weights_b)
|
||||||
|
|
||||||
|
self.adapter_config = adapter_config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_a(self) -> torch.Tensor:
|
||||||
|
if self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_a
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_b(self) -> torch.Tensor:
|
||||||
|
if self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_b
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_a_t(self) -> torch.Tensor:
|
||||||
|
if not self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_a
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights_b_t(self) -> torch.Tensor:
|
||||||
|
if not self._is_transposed:
|
||||||
|
self._transpose_weights()
|
||||||
|
return self._weights_b
|
||||||
|
|
||||||
|
def _transpose_weights(self):
|
||||||
|
if self._use_cutlass_shrink:
|
||||||
|
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
|
||||||
|
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
|
||||||
|
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
|
||||||
|
self._is_transposed = not self._is_transposed
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
|
||||||
|
return [BatchLoraWeights]
|
||||||
|
|
||||||
|
# prepare pre-loaded lora weights for use in the model.
|
||||||
|
#
|
||||||
|
# this method processes and organizes lora weights for a specific layer type across all layers:
|
||||||
|
# - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
|
||||||
|
# - retrieves weights from `module_map` based on the `layer_type`.
|
||||||
|
# - processes `nlayers` number of layers.
|
||||||
|
# - converts weights to the specified `dtype`.
|
||||||
|
# - shards weights across `world_size` number of processes using the `process_group`.
|
||||||
|
# - maps weights to specific layers using `target_to_layer`.
|
||||||
|
# - tracks `unused_weight_names` to identify any unused weights.
|
||||||
|
#
|
||||||
|
# the method handles weight transposition, scaling, and padding to ensure compatibility
|
||||||
|
# with SGMV or BGMV operations.
|
||||||
|
@classmethod
|
||||||
|
def prepare_weights(
|
||||||
|
cls,
|
||||||
|
config: LoraConfig,
|
||||||
|
module_map: Dict[str, Dict],
|
||||||
|
layer_type: str,
|
||||||
|
unused_weight_names: Set[str],
|
||||||
|
nlayers: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
world_size: int,
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
|
||||||
|
) -> Optional[AdapterWeights]:
|
||||||
|
lora_a_list = [None] * nlayers
|
||||||
|
lora_b_list = [None] * nlayers
|
||||||
|
|
||||||
|
for layer_id in range(nlayers):
|
||||||
|
key = (layer_id, layer_type)
|
||||||
|
weight_name, layer = target_to_layer[key]
|
||||||
|
base_weight = layer.base_layer.linear.weight
|
||||||
|
base_device = base_weight.device
|
||||||
|
|
||||||
|
if weight_name not in module_map:
|
||||||
|
# There is no LoRA weight for this layer type in the adapter
|
||||||
|
return None
|
||||||
|
|
||||||
|
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
|
||||||
|
lora_a = lora_a.to(base_device, dtype)
|
||||||
|
|
||||||
|
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
|
||||||
|
lora_b = lora_b.to(base_device, dtype)
|
||||||
|
|
||||||
|
scale = get_scaling_factor(
|
||||||
|
config.lora_alpha,
|
||||||
|
config.r,
|
||||||
|
uses_rslora=config.use_rslora,
|
||||||
|
)
|
||||||
|
|
||||||
|
unused_weight_names.discard(lora_a_name)
|
||||||
|
unused_weight_names.discard(lora_b_name)
|
||||||
|
|
||||||
|
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
|
||||||
|
# (A * B) * C = A * (B * C)
|
||||||
|
lora_a_list[layer_id] = lora_a.transpose(0, 1)
|
||||||
|
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
|
||||||
|
|
||||||
|
# pad lora ranks to be compatible with sgmv
|
||||||
|
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
|
||||||
|
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
|
||||||
|
|
||||||
|
if lora_a_list:
|
||||||
|
# update rank if it was padded
|
||||||
|
padded_rank = lora_a_list[0].size(1)
|
||||||
|
config.r = padded_rank
|
||||||
|
|
||||||
|
return LoraWeights(
|
||||||
|
*shard_lora_weights(
|
||||||
|
weights_a=lora_a_list,
|
||||||
|
weights_b=lora_b_list,
|
||||||
|
split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
|
||||||
|
process_group=process_group,
|
||||||
|
),
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RankSegments:
|
||||||
|
rank: int
|
||||||
|
|
||||||
|
lora_a_ptr: torch.Tensor
|
||||||
|
lora_b_ptr: torch.Tensor
|
||||||
|
|
||||||
|
# prefill (sgmv)
|
||||||
|
tmp_shrink: torch.Tensor
|
||||||
|
tmp_expand: torch.Tensor
|
||||||
|
segment_starts: torch.Tensor
|
||||||
|
segment_ends: torch.Tensor
|
||||||
|
|
||||||
|
# decode (bgmv)
|
||||||
|
indices: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchLoraWeights(BatchAdapterWeights):
|
||||||
|
lora_a: Dict[int, torch.Tensor]
|
||||||
|
lora_b: Dict[int, torch.Tensor]
|
||||||
|
adapter_index_configs: Dict[int, LoraConfig]
|
||||||
|
rank_data: Dict[int, RankSegments]
|
||||||
|
use_sgmv: bool
|
||||||
|
|
||||||
|
def has_adapter(self, adapter_index: int) -> bool:
|
||||||
|
return adapter_index in self.adapter_index_configs
|
||||||
|
|
||||||
|
def can_vectorize(self, pg: ProcessGroup) -> bool:
|
||||||
|
return all(
|
||||||
|
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
|
||||||
|
for rank_data in self.rank_data.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
adapter_weights: Dict[int, AdapterWeights],
|
||||||
|
meta: AdapterBatchMetadata,
|
||||||
|
prefill: bool,
|
||||||
|
prefill_head_indices: Optional[torch.Tensor],
|
||||||
|
) -> Optional["BatchLoraWeights"]:
|
||||||
|
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
|
||||||
|
adapter_weights = {
|
||||||
|
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
|
||||||
|
}
|
||||||
|
if not adapter_weights:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_weights = next(iter(adapter_weights.values()))
|
||||||
|
device = first_weights.weights_a.device
|
||||||
|
segment_indices = meta.segment_indices
|
||||||
|
|
||||||
|
lora_a = {
|
||||||
|
idx: adapter_weights[idx].weights_a
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
}
|
||||||
|
lora_b = {
|
||||||
|
idx: adapter_weights[idx].weights_b
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
}
|
||||||
|
|
||||||
|
max_rank = max(
|
||||||
|
(
|
||||||
|
adapter_weights[idx].lora_a_r
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
),
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prefill or max_rank > BGMV_MAX_RANK:
|
||||||
|
use_sgmv = True
|
||||||
|
lora_a_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_a.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
lora_b_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_b.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
use_sgmv = False
|
||||||
|
lora_a_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_a_t.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
lora_b_ptr = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
adapter_weights[idx].weights_b_t.data_ptr()
|
||||||
|
if idx in adapter_weights
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
for idx in segment_indices
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_index_configs = {
|
||||||
|
idx: adapter_weights[idx].adapter_config
|
||||||
|
for idx in segment_indices
|
||||||
|
if idx in adapter_weights
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
|
||||||
|
|
||||||
|
rank_indices = defaultdict(list)
|
||||||
|
for segment_idx, adapter_idx in enumerate(segment_indices):
|
||||||
|
if adapter_idx not in adapter_weights:
|
||||||
|
continue
|
||||||
|
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
|
||||||
|
|
||||||
|
if prefill_head_indices is not None:
|
||||||
|
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
|
||||||
|
for head_index in prefill_head_indices:
|
||||||
|
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
|
||||||
|
if head_index < meta.adapter_segments[j]:
|
||||||
|
prefill_head_segment_ends[-1] += 1
|
||||||
|
else:
|
||||||
|
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
|
||||||
|
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
rank_data = {}
|
||||||
|
for rank, indices in rank_indices.items():
|
||||||
|
tmp_shrink = None
|
||||||
|
tmp_expand = None
|
||||||
|
segment_starts = None
|
||||||
|
segment_ends = None
|
||||||
|
batch_indices = None
|
||||||
|
|
||||||
|
if use_sgmv:
|
||||||
|
lora_a_ptr_indices = lora_a_ptr[indices]
|
||||||
|
tmp_shrink, tmp_expand = get_tmp_tensors(
|
||||||
|
lora_a_ptr_indices.size(0), rank, device
|
||||||
|
)
|
||||||
|
segment_starts = meta.adapter_segments[indices]
|
||||||
|
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
|
||||||
|
if prefill_head_indices is not None:
|
||||||
|
for i, segment_index in enumerate(indices):
|
||||||
|
segment_starts[i] = prefill_head_segment_starts[segment_index]
|
||||||
|
segment_ends[i] = prefill_head_segment_ends[segment_index]
|
||||||
|
else:
|
||||||
|
rank_indices = set(indices)
|
||||||
|
batch_indices = [
|
||||||
|
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
|
||||||
|
]
|
||||||
|
batch_indices = [
|
||||||
|
idx if idx in rank_indices else -1 for idx in batch_indices
|
||||||
|
]
|
||||||
|
batch_indices = torch.tensor(
|
||||||
|
batch_indices, dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
rank_data[rank] = RankSegments(
|
||||||
|
rank=rank,
|
||||||
|
tmp_shrink=tmp_shrink,
|
||||||
|
tmp_expand=tmp_expand,
|
||||||
|
lora_a_ptr=lora_a_ptr[indices],
|
||||||
|
lora_b_ptr=lora_b_ptr[indices],
|
||||||
|
segment_starts=segment_starts,
|
||||||
|
segment_ends=segment_ends,
|
||||||
|
indices=batch_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BatchLoraWeights(
|
||||||
|
lora_a=lora_a,
|
||||||
|
lora_b=lora_b,
|
||||||
|
adapter_index_configs=adapter_index_configs,
|
||||||
|
rank_data=rank_data,
|
||||||
|
use_sgmv=use_sgmv,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_scaling_factor(
|
||||||
|
lora_alpha: int,
|
||||||
|
r: int,
|
||||||
|
uses_rslora: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Computes the scaling factor for the lora weights."""
|
||||||
|
if uses_rslora:
|
||||||
|
return lora_alpha / (r**0.5)
|
||||||
|
return lora_alpha / r
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
|
||||||
|
if hasattr(v, "lora_weights"):
|
||||||
|
return v.lora_weights
|
||||||
|
return v
|
146
backends/gaudi/server/text_generation_server/adapters/weights.py
Normal file
146
backends/gaudi/server/text_generation_server/adapters/weights.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
# Origin: https://github.com/predibase/lorax
|
||||||
|
# Path: lorax/server/lorax_server/adapters/weights.py
|
||||||
|
# License: Apache License Version 2.0, January 2004
|
||||||
|
|
||||||
|
from abc import ABC, abstractclassmethod
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Set, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdapterBatchMetadata:
|
||||||
|
# [batch_size]
|
||||||
|
adapter_indices: torch.Tensor
|
||||||
|
|
||||||
|
# [num_adapters]
|
||||||
|
adapter_set: Set[int]
|
||||||
|
|
||||||
|
# [num_segments + 1]
|
||||||
|
adapter_segments: torch.Tensor
|
||||||
|
|
||||||
|
# [num_segments]
|
||||||
|
# maps from segment index to adapter index, i.e.:
|
||||||
|
# segment_indices[s] == adapter_indices[i]
|
||||||
|
segment_indices: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterWeights(ABC):
|
||||||
|
@abstractclassmethod
|
||||||
|
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speculative_tokens(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class BatchAdapterWeights(ABC):
|
||||||
|
@abstractclassmethod
|
||||||
|
def has_adapter(self, adapter_index: int) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractclassmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
adapter_weights: Dict[int, AdapterWeights],
|
||||||
|
meta: "AdapterBatchMetadata",
|
||||||
|
prefill: bool,
|
||||||
|
prefill_head_indices: torch.Tensor,
|
||||||
|
) -> Optional["BatchAdapterWeights"]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LayerAdapterWeights:
|
||||||
|
"""Adapter weights that apply to a particular layer."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.adapter_weights: Dict[int, AdapterWeights] = {}
|
||||||
|
|
||||||
|
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
|
||||||
|
self.adapter_weights[adapter_idx] = weights
|
||||||
|
|
||||||
|
def remove_adapter(self, adapter_idx: int):
|
||||||
|
if adapter_idx not in self.adapter_weights:
|
||||||
|
return
|
||||||
|
del self.adapter_weights[adapter_idx]
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
return len(self.adapter_weights) == 0
|
||||||
|
|
||||||
|
def get_data(
|
||||||
|
self,
|
||||||
|
meta: AdapterBatchMetadata,
|
||||||
|
prefill: bool,
|
||||||
|
prefill_head_indices: Optional[torch.Tensor],
|
||||||
|
) -> Dict[str, BatchAdapterWeights]:
|
||||||
|
# bucket adapters by batch class
|
||||||
|
adapter_batch_types: Dict[
|
||||||
|
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
|
||||||
|
] = defaultdict(dict)
|
||||||
|
for adapter_index, adapter_weights in self.adapter_weights.items():
|
||||||
|
for batch_type in adapter_weights.get_batch_types():
|
||||||
|
adapter_batch_types[batch_type][adapter_index] = adapter_weights
|
||||||
|
|
||||||
|
batch_data = {}
|
||||||
|
for batch_type, adapter_weights in adapter_batch_types.items():
|
||||||
|
batched_weights = batch_type.load(
|
||||||
|
adapter_weights, meta, prefill, prefill_head_indices
|
||||||
|
)
|
||||||
|
if batched_weights is not None:
|
||||||
|
batch_data = batched_weights
|
||||||
|
return batch_data
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdapterBatchData:
|
||||||
|
meta: AdapterBatchMetadata
|
||||||
|
|
||||||
|
# layer type -> adapter type -> batch weight data
|
||||||
|
data: Dict[str, Dict[str, BatchAdapterWeights]]
|
||||||
|
|
||||||
|
prefill: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_meta(
|
||||||
|
meta: AdapterBatchMetadata,
|
||||||
|
weights: Dict[str, LayerAdapterWeights],
|
||||||
|
prefill: bool,
|
||||||
|
prefill_head_indices: Optional[torch.Tensor],
|
||||||
|
) -> "AdapterBatchData":
|
||||||
|
data = {}
|
||||||
|
for k, v in weights.items():
|
||||||
|
if v.is_empty():
|
||||||
|
continue
|
||||||
|
data[k] = v.get_data(
|
||||||
|
meta, prefill, prefill_head_indices if k == "lm_head" else None
|
||||||
|
)
|
||||||
|
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
|
||||||
|
|
||||||
|
def ranks(self) -> Set[int]:
|
||||||
|
# TODO(travis): refactor to be less coupled to lora implementation
|
||||||
|
ranks = set()
|
||||||
|
for lora_data in self.data.values():
|
||||||
|
if lora_data is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for rank_data in lora_data.rank_data.values():
|
||||||
|
ranks.add(rank_data.rank)
|
||||||
|
|
||||||
|
return ranks
|
||||||
|
|
||||||
|
def layer_names(self) -> Set[str]:
|
||||||
|
return set(self.data.keys())
|
||||||
|
|
||||||
|
def adapter_keys(self) -> Set[str]:
|
||||||
|
adapter_keys = set()
|
||||||
|
for layer_data in self.data.values():
|
||||||
|
adapter_keys.update(layer_data.keys())
|
||||||
|
return adapter_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_rank(self) -> int:
|
||||||
|
ranks = self.ranks()
|
||||||
|
return max(ranks) if len(ranks) > 0 else 0
|
34
backends/gaudi/server/text_generation_server/cache.py
Normal file
34
backends/gaudi/server/text_generation_server/cache.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Dict, Optional, TypeVar
|
||||||
|
|
||||||
|
from text_generation_server.models.types import Batch
|
||||||
|
|
||||||
|
B = TypeVar("B", bound=Batch)
|
||||||
|
|
||||||
|
|
||||||
|
class Cache:
|
||||||
|
def __init__(self):
|
||||||
|
self.cache: Dict[int, B] = {}
|
||||||
|
|
||||||
|
def pop(self, batch_id: int) -> Optional[B]:
|
||||||
|
return self.cache.pop(batch_id, None)
|
||||||
|
|
||||||
|
def set(self, entry: B):
|
||||||
|
if entry is not None:
|
||||||
|
self.cache[entry.batch_id] = entry
|
||||||
|
|
||||||
|
def delete(self, batch_id: int):
|
||||||
|
batch = self.pop(batch_id)
|
||||||
|
if batch is not None:
|
||||||
|
del batch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
keys = list(self.cache.keys())
|
||||||
|
for k in keys:
|
||||||
|
self.delete(k)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.cache.keys())
|
426
backends/gaudi/server/text_generation_server/cli.py
Normal file
426
backends/gaudi/server/text_generation_server/cli.py
Normal file
@ -0,0 +1,426 @@
|
|||||||
|
import os
|
||||||
|
import psutil
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from text_generation_server.utils.adapter import parse_lora_adapters
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
class Quantization(str, Enum):
|
||||||
|
gptq = "gptq"
|
||||||
|
awq = "awq"
|
||||||
|
fp8 = "fp8"
|
||||||
|
|
||||||
|
|
||||||
|
class Dtype(str, Enum):
|
||||||
|
float16 = "float16"
|
||||||
|
bloat16 = "bfloat16"
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def serve(
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
sharded: bool = False,
|
||||||
|
quantize: Optional[Quantization] = None,
|
||||||
|
speculate: Optional[int] = None,
|
||||||
|
dtype: Optional[Dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
uds_path: Path = "/tmp/text-generation-server",
|
||||||
|
logger_level: str = "INFO",
|
||||||
|
json_output: bool = False,
|
||||||
|
otlp_endpoint: Optional[str] = None,
|
||||||
|
otlp_service_name: str = "text-generation-inference.server",
|
||||||
|
max_input_tokens: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if sharded:
|
||||||
|
# assert (
|
||||||
|
# os.getenv("RANK", None) is not None
|
||||||
|
# ), "RANK must be set when sharded is True"
|
||||||
|
assert (
|
||||||
|
os.getenv("WORLD_SIZE", None) is not None
|
||||||
|
), "WORLD_SIZE must be set when sharded is True"
|
||||||
|
assert (
|
||||||
|
os.getenv("MASTER_ADDR", None) is not None
|
||||||
|
), "MASTER_ADDR must be set when sharded is True"
|
||||||
|
assert (
|
||||||
|
os.getenv("MASTER_PORT", None) is not None
|
||||||
|
), "MASTER_PORT must be set when sharded is True"
|
||||||
|
|
||||||
|
# Remove default handler
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
format="{message}",
|
||||||
|
filter="text_generation_server",
|
||||||
|
level=logger_level,
|
||||||
|
serialize=json_output,
|
||||||
|
backtrace=True,
|
||||||
|
diagnose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import here after the logger is added to log potential import exceptions
|
||||||
|
from text_generation_server import server
|
||||||
|
from text_generation_server.tracing import setup_tracing
|
||||||
|
|
||||||
|
# Setup OpenTelemetry distributed tracing
|
||||||
|
if otlp_endpoint is not None:
|
||||||
|
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
|
||||||
|
|
||||||
|
lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
|
||||||
|
|
||||||
|
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
|
||||||
|
# and warn the user
|
||||||
|
if lora_adapters:
|
||||||
|
logger.warning("LoRA adapters enabled (experimental feature).")
|
||||||
|
|
||||||
|
if "CUDA_GRAPHS" in os.environ:
|
||||||
|
logger.warning(
|
||||||
|
"LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
|
||||||
|
)
|
||||||
|
global CUDA_GRAPHS
|
||||||
|
CUDA_GRAPHS = None
|
||||||
|
|
||||||
|
# Downgrade enum into str for easier management later on
|
||||||
|
quantize = None if quantize is None else quantize.value
|
||||||
|
dtype = "bfloat16" if dtype is None else dtype.value
|
||||||
|
logger.info(f"quantize={quantize}")
|
||||||
|
if dtype is not None and quantize not in {
|
||||||
|
None,
|
||||||
|
"bitsandbytes",
|
||||||
|
"bitsandbytes-nf4",
|
||||||
|
"bitsandbytes-fp4",
|
||||||
|
"gptq",
|
||||||
|
"awq",
|
||||||
|
"fp8",
|
||||||
|
}:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
|
||||||
|
|
||||||
|
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
|
||||||
|
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
|
||||||
|
num_shard = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
logger.info("CLI SHARDED = {}".format(num_shard))
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
cmd = (
|
||||||
|
f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
|
||||||
|
)
|
||||||
|
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
|
||||||
|
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
|
||||||
|
cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
|
||||||
|
if speculate is not None:
|
||||||
|
cmd += f"--speculate {speculate}"
|
||||||
|
logger.info("CLI server start deepspeed ={} ".format(cmd))
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
|
||||||
|
do_terminate = False
|
||||||
|
current_handler = signal.getsignal(signal.SIGTERM)
|
||||||
|
|
||||||
|
def terminate_handler(sig, frame):
|
||||||
|
nonlocal do_terminate
|
||||||
|
do_terminate = True
|
||||||
|
if callable(current_handler):
|
||||||
|
current_handler(sig, frame)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, terminate_handler)
|
||||||
|
|
||||||
|
finished = False
|
||||||
|
while not finished:
|
||||||
|
try:
|
||||||
|
if do_terminate:
|
||||||
|
parent = psutil.Process(proc.pid)
|
||||||
|
all_procs = parent.children(recursive=True) + [parent]
|
||||||
|
for p in all_procs:
|
||||||
|
try:
|
||||||
|
p.terminate()
|
||||||
|
except psutil.NoSuchProcess:
|
||||||
|
pass
|
||||||
|
_, alive = psutil.wait_procs(all_procs, timeout=30)
|
||||||
|
for p in alive:
|
||||||
|
p.kill()
|
||||||
|
|
||||||
|
do_terminate = False
|
||||||
|
|
||||||
|
proc.wait(timeout=3)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
finished = True
|
||||||
|
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
if proc.returncode != 0:
|
||||||
|
logger.error(f"{cmd} exited with status = {proc.returncode}")
|
||||||
|
return proc.returncode
|
||||||
|
else:
|
||||||
|
server.serve(
|
||||||
|
model_id,
|
||||||
|
lora_adapters,
|
||||||
|
revision,
|
||||||
|
sharded,
|
||||||
|
quantize,
|
||||||
|
speculate,
|
||||||
|
dtype,
|
||||||
|
trust_remote_code,
|
||||||
|
uds_path,
|
||||||
|
max_input_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def download_weights(
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
extension: str = ".safetensors",
|
||||||
|
auto_convert: bool = True,
|
||||||
|
logger_level: str = "INFO",
|
||||||
|
json_output: bool = False,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
merge_lora: bool = False,
|
||||||
|
):
|
||||||
|
# Remove default handler
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
format="{message}",
|
||||||
|
filter="text_generation_server",
|
||||||
|
level=logger_level,
|
||||||
|
serialize=json_output,
|
||||||
|
backtrace=True,
|
||||||
|
diagnose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import here after the logger is added to log potential import exceptions
|
||||||
|
from text_generation_server import utils
|
||||||
|
|
||||||
|
# Test if files were already downloaded
|
||||||
|
try:
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
logger.info("Files are already present on the host. " "Skipping download.")
|
||||||
|
return
|
||||||
|
# Local files not found
|
||||||
|
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
|
||||||
|
"WEIGHTS_CACHE_OVERRIDE", None
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
if not is_local_model:
|
||||||
|
# TODO: maybe reverse the default value of merge_lora?
|
||||||
|
# currently by default we don't merge the weights with the base model
|
||||||
|
if merge_lora:
|
||||||
|
try:
|
||||||
|
hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="adapter_config.json"
|
||||||
|
)
|
||||||
|
utils.download_and_unload_peft(
|
||||||
|
model_id, revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
is_local_model = True
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
return
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
utils.peft.download_peft(
|
||||||
|
model_id, revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
|
||||||
|
config = hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
|
with open(config, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
base_model_id = config.get("base_model_name_or_path", None)
|
||||||
|
if base_model_id and base_model_id != model_id:
|
||||||
|
try:
|
||||||
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
|
download_weights(
|
||||||
|
model_id=base_model_id,
|
||||||
|
revision="main",
|
||||||
|
extension=extension,
|
||||||
|
auto_convert=auto_convert,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to download weights from the hub
|
||||||
|
try:
|
||||||
|
filenames = utils.weight_hub_files(model_id, revision, extension)
|
||||||
|
utils.download_weights(filenames, model_id, revision)
|
||||||
|
# Successfully downloaded weights
|
||||||
|
return
|
||||||
|
|
||||||
|
# No weights found on the hub with this extension
|
||||||
|
except utils.EntryNotFoundError as e:
|
||||||
|
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
|
||||||
|
if not extension == ".safetensors" or not auto_convert:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
elif (Path(model_id) / "adapter_config.json").exists():
|
||||||
|
# Try to load as a local PEFT model
|
||||||
|
try:
|
||||||
|
utils.download_and_unload_peft(
|
||||||
|
model_id, revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
return
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
elif (Path(model_id) / "config.json").exists():
|
||||||
|
# Try to load as a local Medusa model
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
|
||||||
|
config = Path(model_id) / "config.json"
|
||||||
|
with open(config, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
base_model_id = config.get("base_model_name_or_path", None)
|
||||||
|
if base_model_id:
|
||||||
|
try:
|
||||||
|
logger.info(f"Downloading parent model {base_model_id}")
|
||||||
|
download_weights(
|
||||||
|
model_id=base_model_id,
|
||||||
|
revision="main",
|
||||||
|
extension=extension,
|
||||||
|
auto_convert=auto_convert,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to see if there are local pytorch weights
|
||||||
|
try:
|
||||||
|
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||||
|
try:
|
||||||
|
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
||||||
|
except Exception:
|
||||||
|
local_pt_files = utils.weight_files(model_id, revision, ".pt")
|
||||||
|
|
||||||
|
# No local pytorch weights
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
if extension == ".safetensors":
|
||||||
|
logger.warning(
|
||||||
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
|
f"Downloading PyTorch weights."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to see if there are pytorch weights on the hub
|
||||||
|
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
|
||||||
|
# Download pytorch weights
|
||||||
|
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
||||||
|
|
||||||
|
if auto_convert:
|
||||||
|
if not trust_remote_code:
|
||||||
|
logger.warning(
|
||||||
|
"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
|
||||||
|
"Pickle files are unsafe and can essentially contain remote code execution!"
|
||||||
|
"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
|
f"Converting PyTorch weights to safetensors."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Safetensors final filenames
|
||||||
|
local_st_files = [
|
||||||
|
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
||||||
|
for p in local_pt_files
|
||||||
|
]
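# (added note, illustrative) e.g. pytorch_model-00001-of-00002.bin becomes
# model-00001-of-00002.safetensors. str.lstrip removes a *character set*, not the
# literal prefix "pytorch_"; it works here because the next stem character ("m")
# is not in that set.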
|
||||||
|
try:
|
||||||
|
import transformers
|
||||||
|
import json
|
||||||
|
|
||||||
|
if is_local_model:
|
||||||
|
config_filename = os.path.join(model_id, "config.json")
|
||||||
|
else:
|
||||||
|
config_filename = hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
|
with open(config_filename, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
architecture = config["architectures"][0]
|
||||||
|
|
||||||
|
class_ = getattr(transformers, architecture)
|
||||||
|
|
||||||
|
# Name for this variable depends on the transformers version.
|
||||||
|
discard_names = getattr(class_, "_tied_weights_keys", [])
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
discard_names = []
|
||||||
|
# Convert pytorch weights to safetensors
|
||||||
|
utils.convert_files(local_pt_files, local_st_files, discard_names)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def quantize(
|
||||||
|
model_id: str,
|
||||||
|
output_dir: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
logger_level: str = "INFO",
|
||||||
|
json_output: bool = False,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
upload_to_model_id: Optional[str] = None,
|
||||||
|
percdamp: float = 0.01,
|
||||||
|
act_order: bool = False,
|
||||||
|
groupsize: int = 128,
|
||||||
|
):
|
||||||
|
if revision is None:
|
||||||
|
revision = "main"
|
||||||
|
download_weights(
|
||||||
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
|
logger_level=logger_level,
|
||||||
|
json_output=json_output,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.gptq.quantize import quantize
|
||||||
|
|
||||||
|
quantize(
|
||||||
|
model_id=model_id,
|
||||||
|
bits=4,
|
||||||
|
groupsize=groupsize,
|
||||||
|
output_dir=output_dir,
|
||||||
|
revision=revision,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
upload_to_model_id=upload_to_model_id,
|
||||||
|
percdamp=percdamp,
|
||||||
|
act_order=act_order,
|
||||||
|
sym=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
@@ -0,0 +1,53 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import os
import habana_frameworks.torch as htorch

quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""

if is_quantization_enabled:
    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
    os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")


def patch_scoped_linear_all_reduce(model):
    from deepspeed.module_inject.layers import LinearAllreduce
    from optimum.habana.transformers.models.modeling_all_models import (
        ScopedLinearAllReduce,
    )

    for name, module in model.named_children():
        if type(module) is LinearAllreduce:
            SL = ScopedLinearAllReduce(mod=module)
            setattr(model, name, SL)
        patch_scoped_linear_all_reduce(module)


def setup_quantization(model):
    if is_quantization_enabled:
        htorch.core.quantization._mark_params_as_const(model)
        htorch.core.quantization._check_params_as_const(model)
        htorch.core.hpu_initialize(model)
    return model


def prepare_model_for_quantization(model):
    if is_quantization_enabled:
        if model.config.model_type in [
            "llama",
            "falcon",
            "qwen2",
            "starcoder2",
            "gemma",
        ]:
            patch_scoped_linear_all_reduce(model)
        from neural_compressor.torch.quantization import FP8Config, convert

        config = FP8Config.from_json_file(quant_config)
        model = convert(model, config)
    return model
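(Illustrative sketch, not part of the diff: one plausible way the helpers above are chained once QUANT_CONFIG points at an FP8 recipe file; the model name and config path are assumptions.)

# assuming QUANT_CONFIG=/path/to/quantization_config.json is exported
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = prepare_model_for_quantization(model)  # patch all-reduce layers, convert eligible modules to FP8
model = setup_quantization(model)              # mark params as const and initialize on HPU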
backends/gaudi/server/text_generation_server/interceptor.py
@@ -0,0 +1,45 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import torch
import grpc

from google.rpc import status_pb2, code_pb2
from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
import traceback
import os


class ExceptionInterceptor(AsyncServerInterceptor):
    async def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        try:
            response = method(request_or_iterator, context)
            return await response
        except Exception as err:
            trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
            method_name = method_name.split("/")[-1]
            logger.exception(f"Method {method_name} encountered an error.")

            # Runtime Error cannot be recovered from
            if isinstance(err, RuntimeError):
                exit(1)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            from .utils.debug import dbg_trace

            dbg_trace("EXCEPTION", traceback.format_exc())
            await context.abort_with_status(
                rpc_status.to_status(
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                )
            )
@@ -0,0 +1,34 @@
from text_generation_server.layers.tensor_parallel import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
)
from text_generation_server.layers.linear import (
    get_linear,
    FastLinear,
)
from text_generation_server.layers.speculative import SpeculativeHead

# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d

from text_generation_server.layers.lora import (
    LoraLinear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)

__all__ = [
    "get_linear",
    "FastLinear",
    "TensorParallelColumnLinear",
    "TensorParallelRowLinear",
    "TensorParallelEmbedding",
    "SpeculativeHead",
    "LoraLinear",
    "TensorParallelMultiAdapterLinear",
    "TensorParallelAdapterRowLinear",
    "load_layer_norm",
    "load_conv2d",
]
@@ -0,0 +1,28 @@
from .common import (
    Seqlen,
    HPUPagedAttentionMetadata,
    trim_attn_metadata,
    trim_seqlen_metadata,
)

from .hpu import (
    SUPPORTS_WINDOWING,
    attention,
    paged_attention,
)


# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales

__all__ = [
    "attention",
    "get_kv_scales",
    "paged_attention",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "Seqlen",
    "HPUPagedAttentionMetadata",
    "trim_seqlen_metadata",
    "trim_attn_metadata",
]
@@ -0,0 +1,147 @@
from dataclasses import dataclass
import torch
from typing import Optional, List, Dict
import collections

_TYPE_CACHE = {}


@dataclass
class HPUPagedAttentionMetadata:
    """Metadata for PagedAttention."""

    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_scales: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]


def subtuple(
    obj: object,
    typename: str,
    to_copy: List[str],
    to_override: Optional[Dict[str, object]] = None,
):
    if obj is None:
        return None
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    if isinstance(obj, dict):
        values = {key: obj[key] for key in fields if key in obj}
    else:
        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
    return _TYPE_CACHE[typename](**values)


def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
    # NOTE(kzawora): To anyone working on this in the future:
    # Trimming metadata is required when using HPUGraphs.
    # Attention metadata is going to be hashed by PT bridge, and
    # appropriate HPUGraphs will be matched based on all inputs' hash.

    # Before you put more keys in here, make sure you know their
    # value type and make sure you know how it's going to be hashed.
    # You can find that information in input_hash function
    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
    # it manually with torch.hpu.graphs.input_hash(attention_metadata)

    # If you use primitive types here - they will get hashed based
    # on their value. You *will* get lots of excessive graph captures
    # (and an OOM eventually) if you decide to put something like
    # seq_len int here.
    # If you absolutely need a scalar, put it in a tensor. Tensors
    # get hashed using their metadata, not their values:
    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
    # input_hash(123) != input_hash(321)
    # input_hash("abc") != input_hash("cba")
    attention_metadata = subtuple(
        metadata,
        "TrimmedAttentionMetadata",
        [
            "block_list",
            "block_mapping",
            "block_usage",
            "block_scales",
            "block_groups",
            "attn_bias",
        ],
    )
    return attention_metadata


@dataclass
class Seqlen:
    input_lengths: torch.Tensor
    cache_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor]
    cu_seqlen_k: Optional[torch.Tensor]

    def __init__(
        self,
        input_lengths,
        cache_lengths,
        cu_seqlen_q=None,
    ):
        self.input_lengths = input_lengths
        self.cache_lengths = cache_lengths
        device = self.input_lengths.device
        shape = self.input_lengths.shape
        if cu_seqlen_q is None:
            cu_seqlen_q = torch.arange(
                shape[0] + 1,
                device=device,
                dtype=torch.int32,
            )
        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

        # cuda graphs don't like this and this is necessary to clamp within mistral
        # Although FA2 might not want the clamping
        # cu_seqlen_k[0] = 0
        total = self.input_lengths + self.cache_lengths
        torch.cumsum(total, -1, out=cu_seqlen_k[1:])

        self.cu_seqlen_q = cu_seqlen_q
        self.cu_seqlen_k = cu_seqlen_k

    def clamp(self, max):
        # Flash decoding doesn't need to clamp
        return self


def trim_seqlen_metadata(metadata: Seqlen) -> object:
    # NOTE(kzawora): To anyone working on this in the future:
    # Trimming metadata is required when using HPUGraphs.
    # Attention metadata is going to be hashed by PT bridge, and
    # appropriate HPUGraphs will be matched based on all inputs' hash.

    # Before you put more keys in here, make sure you know their
    # value type and make sure you know how it's going to be hashed.
    # You can find that information in input_hash function
    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
    # it manually with torch.hpu.graphs.input_hash(attention_metadata)

    # If you use primitive types here - they will get hashed based
    # on their value. You *will* get lots of excessive graph captures
    # (and an OOM eventually) if you decide to put something like
    # seq_len int here.
    # If you absolutely need a scalar, put it in a tensor. Tensors
    # get hashed using their metadata, not their values:
    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
    # input_hash(123) != input_hash(321)
    # input_hash("abc") != input_hash("cba")
    attention_metadata = subtuple(
        metadata,
        "TrimmedSeqlen",
        [
            "input_lengths",
            "cache_lengths",
            "cu_seqlen_q",
            "cu_seqlen_k",
        ],
    )
    return attention_metadata
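(Illustrative sketch, not part of the diff: building a Seqlen for a two-sequence batch and trimming it before an HPU graph capture; the lengths are assumptions.)

input_lengths = torch.tensor([5, 3], dtype=torch.int32)
cache_lengths = torch.zeros(2, dtype=torch.int32)
seqlen = Seqlen(input_lengths, cache_lengths)
trimmed = trim_seqlen_metadata(seqlen)  # namedtuple of tensors only, so the graph hash stays stable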
@ -0,0 +1,95 @@
|
|||||||
|
import torch
|
||||||
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
||||||
|
from typing import Optional
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||||
|
from vllm_hpu_extension import ops
|
||||||
|
from vllm_hpu_extension.utils import Matmul
|
||||||
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||||
|
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||||
|
import os
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = False
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_from_cache(cache, blocks):
|
||||||
|
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
|
||||||
|
return cache[: blocks.size(0)]
|
||||||
|
else:
|
||||||
|
return cache.index_select(0, blocks)
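# (added note) With VLLM_CONTIGUOUS_PA enabled (the default here), the referenced
# blocks are assumed to form a contiguous prefix of the cache, so a cheap slice is
# enough; otherwise the blocks are gathered explicitly with index_select.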
|
||||||
|
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
*,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
kv_scales: KVScales,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
softmax_scale: float,
|
||||||
|
window_size_left: int = -1,
|
||||||
|
causal: bool = True,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
):
|
||||||
|
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||||
|
bs = seqlen.input_lengths.shape[0]
|
||||||
|
_, head_num, head_size = query.shape
|
||||||
|
_, kv_head_num, head_size = key.shape
|
||||||
|
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
|
||||||
|
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
|
||||||
|
value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
|
||||||
|
attn_output = fsdpa_op(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_mask=None,
|
||||||
|
dropout_p=0.0,
|
||||||
|
is_causal=causal,
|
||||||
|
scale=softmax_scale,
|
||||||
|
softmax_mode="None",
|
||||||
|
recompute_mode=None,
|
||||||
|
valid_sequence_lengths=seqlen.input_lengths,
|
||||||
|
padding_side="left",
|
||||||
|
)
|
||||||
|
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
kv_head_mapping: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
*,
|
||||||
|
kv_scales: KVScales,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
hpu_attention_meta: HPUPagedAttentionMetadata,
|
||||||
|
):
|
||||||
|
batch_size, head_num, head_size = query.shape
|
||||||
|
output = ops.flat_pa(
|
||||||
|
query=query.view(batch_size, 1, head_num * head_size),
|
||||||
|
key_cache=kv_cache.key,
|
||||||
|
value_cache=kv_cache.value,
|
||||||
|
block_list=hpu_attention_meta.block_list,
|
||||||
|
block_mapping=hpu_attention_meta.block_mapping,
|
||||||
|
block_bias=hpu_attention_meta.attn_bias,
|
||||||
|
block_scales=hpu_attention_meta.block_scales,
|
||||||
|
block_groups=hpu_attention_meta.block_groups,
|
||||||
|
scale=softmax_scale,
|
||||||
|
matmul_qk_op=Matmul(),
|
||||||
|
matmul_av_op=Matmul(),
|
||||||
|
batch2block_matmul_op=Matmul(),
|
||||||
|
block2batch_matmul_op=Matmul(),
|
||||||
|
keys_fetch_func=fetch_from_cache,
|
||||||
|
values_fetch_func=fetch_from_cache,
|
||||||
|
)
|
||||||
|
# Reshape the output tensor.
|
||||||
|
return output.view(batch_size, head_num, head_size)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SUPPORTS_WINDOWING",
|
||||||
|
"attention",
|
||||||
|
"paged_attention",
|
||||||
|
]
|
@ -0,0 +1,139 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from text_generation_server.models.globals import BLOCK_SIZE
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
from vllm_hpu_extension import cache_ops
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KVScales:
|
||||||
|
"""
|
||||||
|
Key-value scales for FP8 KV cache.
|
||||||
|
|
||||||
|
This data class stores key and value scales both as a GPU tensor and
|
||||||
|
as a CPU float. This inconvenience is necessary because some functions
|
||||||
|
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
|
||||||
|
(e.g. flashinfer) take scales as a CPU scalar.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key_scale: torch.Tensor
|
||||||
|
value_scale: torch.Tensor
|
||||||
|
key_scale_cpu: float = field(init=False)
|
||||||
|
value_scale_cpu: float = field(init=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
|
||||||
|
raise ValueError("Key and value scales must be scalar tensors.")
|
||||||
|
|
||||||
|
self.key_scale_cpu = self.key_scale.item()
|
||||||
|
self.value_scale_cpu = self.value_scale.item()
|
||||||
|
|
||||||
|
|
||||||
|
class KVCache:
|
||||||
|
"""
|
||||||
|
Key-value cache for attention layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kv_cache: Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_blocks: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
"""Construct the key-value cache for a layer."""
|
||||||
|
## TODO FP8 kv cache support
|
||||||
|
|
||||||
|
self.kv_cache = (
|
||||||
|
torch.zeros(
|
||||||
|
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""Get the data type of the cache."""
|
||||||
|
return self.kv_cache[0].dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key(self):
|
||||||
|
"""Get the key cache."""
|
||||||
|
|
||||||
|
return self.kv_cache[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
"""Get the value cache."""
|
||||||
|
|
||||||
|
return self.kv_cache[1]
|
||||||
|
|
||||||
|
def store(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
kv_scales: KVScales,
|
||||||
|
):
|
||||||
|
"""Store the key and value at the given slots."""
|
||||||
|
## TODO FP8 kv cache support
|
||||||
|
|
||||||
|
key_cache = self.kv_cache[0]
|
||||||
|
value_cache = self.kv_cache[1]
|
||||||
|
|
||||||
|
paged_reshape_and_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
slots,
|
||||||
|
kv_scales.key_scale_cpu,
|
||||||
|
kv_scales.value_scale_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
|
):
|
||||||
|
block_idx = slots // BLOCK_SIZE
|
||||||
|
block_offset = slots % BLOCK_SIZE
|
||||||
|
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
|
||||||
|
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
|
||||||
|
"""Load KV cache scales."""
|
||||||
|
|
||||||
|
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
|
||||||
|
value_scale = key_scale
|
||||||
|
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
|
||||||
|
f"{prefix}.v_scale"
|
||||||
|
):
|
||||||
|
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
|
||||||
|
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
|
||||||
|
elif weights.has_tensor(f"{prefix}.kv_scale"):
|
||||||
|
# Fall back to older more coarse-grained scale when available.
|
||||||
|
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
|
||||||
|
value_scale = key_scale
|
||||||
|
|
||||||
|
return KVScales(key_scale=key_scale, value_scale=value_scale)
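(Illustrative sketch, not part of the diff: allocating a cache and storing one step of keys/values; the prefix, the weights object, the shapes, and the k/v/slots tensors are all assumptions.)

scales = get_kv_scales(weights, "model.layers.0.self_attn")
cache = KVCache(
    num_blocks=16, num_heads=8, head_size=128,
    dtype=torch.bfloat16, device=torch.device("hpu"),
)
cache.store(key=k, value=v, slots=slots, kv_scales=scales)  # k/v: [num_tokens, 8, 128], slots: [num_tokens]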
|
@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
|
||||||
|
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def pack(imatrix: torch.Tensor, direction: str = "column"):
|
||||||
|
"""
|
||||||
|
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
direction (str): direction of packing, either "column" or "row"
|
||||||
|
Returns:
|
||||||
|
qmatrix (torch.Tensor): packed matrix of integers
|
||||||
|
"""
|
||||||
|
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
|
||||||
|
|
||||||
|
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
|
||||||
|
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
|
||||||
|
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
|
||||||
|
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
|
||||||
|
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
|
||||||
|
|
||||||
|
qmatrix = qmatrix.to(torch.int32)
|
||||||
|
|
||||||
|
return qmatrix
|
||||||
|
|
||||||
|
|
||||||
|
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
|
||||||
|
"""
|
||||||
|
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
qmatrix (torch.Tensor): matrix of packed integers
|
||||||
|
direction (str): direction of unpacking, either "column" or "row"
|
||||||
|
Returns:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
"""
|
||||||
|
shifts = torch.arange(0, 32, 4, device=qmatrix.device)
|
||||||
|
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = torch.bitwise_right_shift(
|
||||||
|
qmatrix[:, :, None], shifts[None, None, :]
|
||||||
|
).view(qmatrix.shape[0], -1)
|
||||||
|
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = torch.bitwise_right_shift(
|
||||||
|
qmatrix[:, None, :], shifts[None, :, None]
|
||||||
|
).view(-1, qmatrix.shape[-1])
|
||||||
|
|
||||||
|
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
|
||||||
|
|
||||||
|
return imatrix
|
||||||
|
|
||||||
|
|
||||||
|
def apply_order(
|
||||||
|
imatrix: torch.Tensor,
|
||||||
|
direction: str = "column",
|
||||||
|
order: List[int] = AWQ_PACK_ORDER,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Applies the order to a 4-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
direction (str): direction of applying order, either "column" or "row"
|
||||||
|
order (List[int]): order to apply, default is AWQ_PACK_ORDER
|
||||||
|
Returns:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
"""
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
|
||||||
|
|
||||||
|
return imatrix
|
||||||
|
|
||||||
|
|
||||||
|
def fast_awq_to_gptq(qweight, qzeros):
|
||||||
|
# awq uses column packing for both weights and zeros
|
||||||
|
izeros = unpack(qzeros, direction="column")
|
||||||
|
iweights = unpack(qweight, direction="column")
|
||||||
|
|
||||||
|
# Reverse the order of the iweight and izeros tensors
|
||||||
|
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
|
||||||
|
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
|
||||||
|
# Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
|
||||||
|
izeros = izeros - 1
|
||||||
|
# exllama uses row packing for weights and column packing for zeros
|
||||||
|
qzeros = pack(izeros, direction="column")
|
||||||
|
qweight = pack(iweights, direction="row")
|
||||||
|
|
||||||
|
return qweight, qzeros
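(Illustrative self-check, not part of the diff: eight 4-bit values are packed per int32, so a random 4-bit matrix should round-trip losslessly through pack/unpack; the shape is an assumption.)

imatrix = torch.randint(0, 16, (64, 32), dtype=torch.int32)
assert torch.equal(unpack(pack(imatrix, "column"), "column"), imatrix)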
|
@@ -0,0 +1,3 @@
from .hpu import WQLinear

__all__ = ["WQLinear"]
@ -0,0 +1,134 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
try:
|
||||||
|
import habana_frameworks.torch.hpu # noqa: F401
|
||||||
|
|
||||||
|
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
|
||||||
|
except Exception as e:
|
||||||
|
hpu_import_exception = e
|
||||||
|
|
||||||
|
def error_raiser_hpu(*args, **kwargs):
|
||||||
|
raise ValueError(
|
||||||
|
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
|
||||||
|
)
|
||||||
|
|
||||||
|
convert_from_uint4 = error_raiser_hpu
|
||||||
|
|
||||||
|
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
|
||||||
|
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||||
|
|
||||||
|
# unpacking columnwise
|
||||||
|
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
|
||||||
|
torch.int8 # smallest dtype available
|
||||||
|
)
|
||||||
|
iweights = iweights.view(iweights.shape[0], -1)
|
||||||
|
|
||||||
|
# unpacking columnwise
|
||||||
|
if qzeros is not None:
|
||||||
|
izeros = torch.bitwise_right_shift(
|
||||||
|
qzeros[:, :, None], shifts[None, None, :]
|
||||||
|
).to(
|
||||||
|
torch.int8 # smallest dtype available
|
||||||
|
)
|
||||||
|
izeros = izeros.view(izeros.shape[0], -1)
|
||||||
|
else:
|
||||||
|
izeros = qzeros
|
||||||
|
|
||||||
|
return iweights, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
|
||||||
|
reverse_order_tensor = torch.arange(
|
||||||
|
iweights.shape[-1],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=izeros.device,
|
||||||
|
)
|
||||||
|
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||||
|
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
|
||||||
|
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||||
|
|
||||||
|
if izeros is not None:
|
||||||
|
izeros = izeros[:, reverse_order_tensor]
|
||||||
|
iweights = iweights[:, reverse_order_tensor]
|
||||||
|
|
||||||
|
return iweights, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_weight_and_zeros(qweight, qzeros, bits):
|
||||||
|
# Unpack the qweight and qzeros tensors
|
||||||
|
iweight, izeros = unpack_awq(qweight, qzeros, bits)
|
||||||
|
# Reverse the order of the iweight and izeros tensors
|
||||||
|
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
|
||||||
|
|
||||||
|
# overflow checks
|
||||||
|
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||||
|
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||||
|
|
||||||
|
return iweight, izeros
|
||||||
|
|
||||||
|
|
||||||
|
def pack_tensor(input, bits=4):
|
||||||
|
normal = input.to(torch.int32)
|
||||||
|
q = torch.zeros(
|
||||||
|
(normal.shape[0], normal.shape[1] // 32 * bits),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=input.device,
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
col = 0
|
||||||
|
while col < q.shape[1]:
|
||||||
|
for j in range(i, i + (32 // bits)):
|
||||||
|
q[:, col] |= normal[:, j] << (bits * (j - i))
|
||||||
|
i += 32 // bits
|
||||||
|
col += 1
|
||||||
|
q = q.to(torch.int32)
|
||||||
|
return q
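# (added note, illustrative) With bits=4, each output column packs eight nibbles,
# lowest input value in the lowest bits:
#   pack_tensor(torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]))  ->  tensor([[0x76543210]])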
|
||||||
|
|
||||||
|
|
||||||
|
class WQLinear(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if w_bit not in [4]:
|
||||||
|
raise NotImplementedError("Only 4-bit are supported for now.")
|
||||||
|
|
||||||
|
self.in_features = qweight.shape[0]
|
||||||
|
self.out_features = qweight.shape[1] * 32 // w_bit
|
||||||
|
|
||||||
|
self.w_bit = w_bit
|
||||||
|
self.group_size = group_size if group_size != -1 else self.in_features
|
||||||
|
# quick sanity check (make sure alignment)
|
||||||
|
assert self.in_features % self.group_size == 0
|
||||||
|
assert self.out_features % (32 // self.w_bit) == 0
|
||||||
|
|
||||||
|
self.qweight = qweight
|
||||||
|
self.qzeros = qzeros
|
||||||
|
self.scales = scales
|
||||||
|
self.bias = bias
|
||||||
|
self._preprocessing()
|
||||||
|
|
||||||
|
def _preprocessing(self):
|
||||||
|
device = self.qweight.device
|
||||||
|
weight, zeros = unpack_weight_and_zeros(
|
||||||
|
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
|
||||||
|
)
|
||||||
|
self.qweight = pack_tensor(weight).to(device)
|
||||||
|
self.qzeros = pack_tensor(zeros).to(device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, x):
|
||||||
|
out_shape = x.shape[:-1] + (self.out_features,)
|
||||||
|
x = x.reshape(-1, x.shape[-1])
|
||||||
|
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
|
||||||
|
outputs = torch.matmul(x, weights)
|
||||||
|
|
||||||
|
outputs = outputs + self.bias if self.bias is not None else outputs
|
||||||
|
outputs = outputs.reshape(out_shape)
|
||||||
|
return outputs
|
backends/gaudi/server/text_generation_server/layers/bnb.py
@@ -0,0 +1,124 @@
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import torch
|
||||||
|
from bitsandbytes.nn import Int8Params, Params4bit
|
||||||
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BNBWeight(UnquantizedWeight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
|
||||||
|
|
||||||
|
|
||||||
|
class Linear8bitLt(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
has_fp16_weights=True,
|
||||||
|
memory_efficient_backward=False,
|
||||||
|
threshold=0.0,
|
||||||
|
index=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert (
|
||||||
|
not memory_efficient_backward
|
||||||
|
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||||
|
self.state = bnb.MatmulLtState()
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
# Necessary for stacked layers
|
||||||
|
self.state.threshold = threshold
|
||||||
|
self.state.has_fp16_weights = has_fp16_weights
|
||||||
|
self.state.memory_efficient_backward = memory_efficient_backward
|
||||||
|
if threshold > 0.0 and not has_fp16_weights:
|
||||||
|
self.state.use_pool = True
|
||||||
|
|
||||||
|
self.weight = Int8Params(
|
||||||
|
weight.data,
|
||||||
|
has_fp16_weights=has_fp16_weights,
|
||||||
|
requires_grad=has_fp16_weights,
|
||||||
|
)
|
||||||
|
self.weight.cuda(weight.device)
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
|
def init_8bit_state(self):
|
||||||
|
self.state.CB = self.weight.CB
|
||||||
|
self.state.SCB = self.weight.SCB
|
||||||
|
self.weight.CB = None
|
||||||
|
self.weight.SCB = None
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
self.state.is_training = self.training
|
||||||
|
if self.weight.CB is not None:
|
||||||
|
self.init_8bit_state()
|
||||||
|
|
||||||
|
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||||
|
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||||
|
self.bias.data = self.bias.data.to(x.dtype)
|
||||||
|
|
||||||
|
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||||
|
|
||||||
|
if not self.state.has_fp16_weights:
|
||||||
|
if self.state.CB is not None and self.state.CxB is not None:
|
||||||
|
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||||
|
# we no longer need the row-major weight
|
||||||
|
del self.state.CB
|
||||||
|
self.weight.data = self.state.CxB
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BNBFP4Weight(UnquantizedWeight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
return Linear4bit(self.weight, bias, quant_type="fp4")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BNBNF4Weight(UnquantizedWeight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
return Linear4bit(self.weight, bias, quant_type="nf4")
|
||||||
|
|
||||||
|
|
||||||
|
class Linear4bit(torch.nn.Module):
|
||||||
|
def __init__(self, weight, bias, quant_type):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = Params4bit(
|
||||||
|
weight.data,
|
||||||
|
requires_grad=False,
|
||||||
|
compress_statistics=True,
|
||||||
|
quant_type=quant_type,
|
||||||
|
)
|
||||||
|
self.compute_dtype = None
|
||||||
|
self.weight.cuda(weight.device)
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
# weights are quantized automatically as Params4bit, but the bias has to be cast manually
|
||||||
|
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||||
|
self.bias.data = self.bias.data.to(x.dtype)
|
||||||
|
|
||||||
|
if getattr(self.weight, "quant_state", None) is None:
|
||||||
|
print(
|
||||||
|
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
|
||||||
|
)
|
||||||
|
inp_dtype = x.dtype
|
||||||
|
if self.compute_dtype is not None:
|
||||||
|
x = x.to(self.compute_dtype)
|
||||||
|
|
||||||
|
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||||
|
out = bnb.matmul_4bit(
|
||||||
|
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
|
||||||
|
)
|
||||||
|
|
||||||
|
out = out.to(inp_dtype)
|
||||||
|
|
||||||
|
return out
|
backends/gaudi/server/text_generation_server/layers/conv.py
@@ -0,0 +1,41 @@
from accelerate import init_empty_weights
import torch


@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    conv2d.bias = torch.nn.Parameter(bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    conv2d.bias = None
    return conv2d


torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
backends/gaudi/server/text_generation_server/layers/exl2.py
@@ -0,0 +1,78 @@
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Exl2Weight(Weight):
|
||||||
|
"""
|
||||||
|
Exllama2 exl2 quantized weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
q_weight: torch.Tensor
|
||||||
|
q_scale: torch.Tensor
|
||||||
|
q_invperm: torch.Tensor
|
||||||
|
q_scale_max: torch.Tensor
|
||||||
|
q_groups: torch.Tensor
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.q_scale_max /= 256
|
||||||
|
self.q_invperm = self.q_invperm.short()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.q_weight.device
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
||||||
|
|
||||||
|
return ExllamaQuantLinear(self, bias)
|
||||||
|
|
||||||
|
|
||||||
|
class Exl2WeightsLoader(WeightsLoader):
|
||||||
|
"""Loader for exl2-quantized weights."""
|
||||||
|
|
||||||
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
|
"""
|
||||||
|
Get weights at the given prefix and apply without tensor parallelism.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
|
||||||
|
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||||
|
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||||
|
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||||
|
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||||
|
|
||||||
|
return Exl2Weight(
|
||||||
|
q_weight=q_weight,
|
||||||
|
q_scale=q_scale,
|
||||||
|
q_invperm=q_invperm,
|
||||||
|
q_scale_max=q_scale_max,
|
||||||
|
q_groups=q_groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
raise RuntimeError("Column-packed weights are not supported for exl")
|
||||||
|
|
||||||
|
def get_weights_col(self, weights: Weights, prefix: str):
|
||||||
|
# Sharding is not yet supported, so we return the weights as-is.
|
||||||
|
return self.get_weights(weights, prefix)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
# Sharding is not yet supported, so we return the weights as-is.
|
||||||
|
return self.get_weights(weights, prefix)
|
backends/gaudi/server/text_generation_server/layers/fp8.py
@@ -0,0 +1,458 @@
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple, Type, Union, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from text_generation_server.utils.weights import (
|
||||||
|
Weight,
|
||||||
|
WeightsLoader,
|
||||||
|
UnquantizedWeight,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm_hpu_extension.ops import scaled_fp8_quant
|
||||||
|
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
|
||||||
|
import habana_frameworks.torch.utils.experimental as htexp
|
||||||
|
|
||||||
|
w8a8_block_fp8_matmul = None
|
||||||
|
per_token_group_quant_fp8 = None
|
||||||
|
quant_dtype: torch.dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
|
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
|
||||||
|
"""
|
||||||
|
Return an FP8 linear `Module` that is compatible with the current system.
|
||||||
|
"""
|
||||||
|
# On other systems let Torch decide if the hardware supports FP8.
|
||||||
|
return Fp8Linear
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_e4m3fn_to_native_float8(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
return weight, weight_scale, input_scale
|
||||||
|
|
||||||
|
|
||||||
|
def per_tensor_dequantize(
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
inv_scale: Union[float, torch.Tensor],
|
||||||
|
dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
device = tensor.device
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
|
||||||
|
# dequant on cpu to avoid nan on gaudi2
|
||||||
|
tensor = tensor.to("cpu")
|
||||||
|
|
||||||
|
fake_qweight = tensor.to(dtype).to(device)
|
||||||
|
dq_weight = fake_qweight * inv_scale
|
||||||
|
return dq_weight
|
||||||
|
|
||||||
|
|
||||||
|
def requantize_with_max_scale(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
logical_widths: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Max scale to be used for requantization.
|
||||||
|
max_w_scale = weight_scale.max()
|
||||||
|
|
||||||
|
if is_hpu_gaudi2():
|
||||||
|
max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
for idx, logical_width in enumerate(logical_widths):
|
||||||
|
end = start + logical_width
|
||||||
|
weight_dq = per_tensor_dequantize(
|
||||||
|
weight[start:end, :], weight_scale[idx], dtype
|
||||||
|
)
|
||||||
|
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
|
||||||
|
weight_dq, max_w_scale
|
||||||
|
)
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return weight, max_w_scale_normalized
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_quantize(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[torch.Tensor] = None,
|
||||||
|
qdtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
|
scalar: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This function returns a reciprocal of the scale, so that a tensor can be unscaled
|
||||||
|
by multiplying it with the returned scale. If a scale is given through the `scale`
|
||||||
|
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
|
||||||
|
be used without modification).
|
||||||
|
"""
|
||||||
|
shape = weight.shape
|
||||||
|
qweight, scale = scaled_fp8_quant(
|
||||||
|
weight.reshape(-1, shape[-1]),
|
||||||
|
scale=scale,
|
||||||
|
scale_ub=scale_upper_bound,
|
||||||
|
# TODO: don't do this when we have to use the Torch kernel.
|
||||||
|
use_per_token_if_dynamic=not scalar,
|
||||||
|
)
|
||||||
|
|
||||||
|
return qweight.reshape(shape), scale
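(Illustrative sketch, not part of the diff, following the reciprocal-scale contract described in the docstring above; the weight tensor and device are assumptions.)

w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="hpu")  # assumes an HPU device
qweight, scale = fp8_quantize(w, scalar=True)
w_approx = qweight.to(w.dtype) * scale  # unscale by multiplying with the returned scale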
|
||||||
|
|
||||||
|
|
||||||
|
class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
|
"""Weight loader that loads FP8 and unquantized Torch tensors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation_scale_ub: Optional[float],
|
||||||
|
to_fp8: bool,
|
||||||
|
weight_block_size: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
|
self.activation_scale_ub = activation_scale_ub
|
||||||
|
self.to_fp8 = to_fp8
|
||||||
|
self.weight_block_size = weight_block_size
|
||||||
|
|
||||||
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
|
w = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = (
|
||||||
|
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||||
|
.reshape(-1)
|
||||||
|
.max()
|
||||||
|
)
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
w = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
if scale.numel() > 1:
|
||||||
|
scale = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.weight_scale",
|
||||||
|
dim=0,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
to_dtype=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = weights.get_tensor(
|
||||||
|
f"{prefix}.input_scale", to_dtype=False
|
||||||
|
)
|
||||||
|
if input_scale.numel() > 1:
|
||||||
|
input_scale = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.input_scale",
|
||||||
|
dim=0,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
to_dtype=False,
|
||||||
|
)
|
||||||
|
input_scale = input_scale.reshape(-1).max()
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||||
|
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
|
||||||
|
w = [
|
||||||
|
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
|
||||||
|
]
|
||||||
|
shapes = [x.shape for x in w]
|
||||||
|
|
||||||
|
# Concat then send to the device
|
||||||
|
w = torch.cat(w, dim=dim).to(weights.device)
|
||||||
|
|
||||||
|
# FP8 branch
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
scale = [
|
||||||
|
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
|
||||||
|
for p in prefixes
|
||||||
|
]
|
||||||
|
scale = torch.cat(scale, dim=dim)
|
||||||
|
scale = scale.to(weights.device)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = [
|
||||||
|
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||||
|
for p, shape in zip(prefixes, shapes)
|
||||||
|
]
|
||||||
|
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||||
|
|
||||||
|
input_scale = [
|
||||||
|
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||||
|
for p, shape in zip(prefixes, shapes)
|
||||||
|
if weights.has_tensor(f"{p}.input_scale")
|
||||||
|
]
|
||||||
|
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||||
|
input_scale = (
|
||||||
|
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||||
|
if len(input_scale) != 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logical_widths = [x[0] for x in shapes]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.to(weights.device), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||||
|
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
|
# FP8 branch
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
# XXX: The weight is named scale_inv, but it appears to hold the scale itself.
|
||||||
|
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
|
||||||
|
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = (
|
||||||
|
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||||
|
.reshape(-1)
|
||||||
|
.max()
|
||||||
|
)
|
||||||
|
logical_widths = [w.shape[0]]
|
||||||
|
w, scale = requantize_with_max_scale(
|
||||||
|
w, scale.unsqueeze(0), logical_widths, weights.dtype
|
||||||
|
)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Fp8Weight(Weight):
|
||||||
|
weight: torch.Tensor
|
||||||
|
dtype: torch.dtype
|
||||||
|
weight_scale: Optional[torch.Tensor] = None
|
||||||
|
input_scale: Optional[torch.Tensor] = None
|
||||||
|
activation_scale_ub: Optional[float] = None
|
||||||
|
force_w8a16: bool = False
|
||||||
|
weight_block_size: Optional[List[int]] = None
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
if self.weight_scale is None:
|
||||||
|
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
|
||||||
|
self.weight, bias, self.dtype
|
||||||
|
)
|
||||||
|
# This is not checked by the fbgemm kernels, but they require contiguous
|
||||||
|
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
||||||
|
self.weight_scale = self.weight_scale.contiguous()
|
||||||
|
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
|
||||||
|
weight=self.weight,
|
||||||
|
scale=self.weight_scale,
|
||||||
|
dtype=self.dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=self.input_scale,
|
||||||
|
scale_upper_bound=self.activation_scale_ub,
|
||||||
|
weight_block_size=self.weight_block_size,
|
||||||
|
)
|
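The per-tensor path above (Fp8Weight.get_linear calling from_unquant, and later the dense branch of Fp8Linear.forward) boils down to scaling a tensor so that its maximum magnitude fits the float8_e4m3fn range, keeping the scale, and undoing it at matmul time. The following is an illustrative, self-contained sketch of that round trip; quantize_per_tensor_fp8 is a hypothetical helper written for this example, not the repository's fp8_quantize, and the dequantize-and-matmul reference stands in for the torch._scaled_mm call used on FP8-capable hardware.

import torch

def quantize_per_tensor_fp8(x: torch.Tensor):
    # Map the max magnitude of `x` onto the float8_e4m3fn maximum (448.0).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().max().clamp(min=1e-12).float() / finfo.max
    qx = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qx, scale

x = torch.randn(16, 64)   # activations
w = torch.randn(128, 64)  # weight laid out as [out_features, in_features]

qx, x_scale = quantize_per_tensor_fp8(x)
qw, w_scale = quantize_per_tensor_fp8(w)

# Reference result of the scaled matmul: dequantize, then a plain matmul.
# On FP8-capable GPUs the forward pass delegates this product to torch._scaled_mm
# (note the qweight.t() in forward(), since _scaled_mm wants [in, out] on the right).
y = (qx.float() * x_scale) @ (qw.float() * w_scale).t()
print(y.shape)  # torch.Size([16, 128])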
||||||
|
|
||||||
|
|
||||||
|
class Fp8Linear(torch.nn.Module):
|
||||||
|
_device_identity_cache = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qweight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[float] = None,
|
||||||
|
weight_block_size: Optional[List[int]] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
self.qweight = qweight
|
||||||
|
self.scale = scale.float()
|
||||||
|
self.input_scale = input_scale.float() if input_scale is not None else None
|
||||||
|
self.weight_block_size = weight_block_size
|
||||||
|
self.scale_upper_bound = scale_upper_bound
|
||||||
|
|
||||||
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unquant(cls, weight, bias, dtype):
|
||||||
|
qweight, scale = fp8_quantize(weight, scalar=True)
|
||||||
|
return cls(
|
||||||
|
qweight=qweight,
|
||||||
|
scale=scale,
|
||||||
|
dtype=dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=None,
|
||||||
|
scale_upper_bound=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_fp8(
|
||||||
|
cls,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "Fp8Linear":
|
||||||
|
input_scale = kwargs.get("input_scale", None)
|
||||||
|
scale_upper_bound = kwargs.get("scale_upper_bound", None)
|
||||||
|
weight_block_size = kwargs.get("weight_block_size", None)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
qweight=weight,
|
||||||
|
scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
|
scale_upper_bound=scale_upper_bound,
|
||||||
|
bias=bias,
|
||||||
|
dtype=dtype,
|
||||||
|
weight_block_size=weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_shared_device_identity(cls, device):
|
||||||
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
|
# from PyTorch 2.5, so we allocate a dummy tensor to pass as input_scale.
|
||||||
|
if device not in cls._device_identity_cache:
|
||||||
|
cls._device_identity_cache[device] = torch.ones(1, device=device)
|
||||||
|
return cls._device_identity_cache[device]
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.weight_block_size is not None:
|
||||||
|
# https://arxiv.org/pdf/2412.19437
|
||||||
|
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
|
||||||
|
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
|
||||||
|
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
|
||||||
|
# channels).
|
||||||
|
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
|
||||||
|
output = w8a8_block_fp8_matmul(
|
||||||
|
qinput,
|
||||||
|
self.qweight,
|
||||||
|
scale,
|
||||||
|
self.scale,
|
||||||
|
self.weight_block_size,
|
||||||
|
output_dtype=input.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
output = output + self.bias
|
||||||
|
return output.to(dtype=input.dtype)
|
||||||
|
|
||||||
|
qinput, scale = fp8_quantize(
|
||||||
|
input,
|
||||||
|
self.input_scale,
|
||||||
|
scale_upper_bound=self.scale_upper_bound,
|
||||||
|
scalar=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
qinput,
|
||||||
|
self.qweight.t(),
|
||||||
|
out_dtype=self.dtype,
|
||||||
|
scale_a=scale,
|
||||||
|
scale_b=self.scale,
|
||||||
|
bias=self.bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple) and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
|
||||||
|
scale = weights.get_tensor(prefix, to_dtype=False)
|
||||||
|
|
||||||
|
if scale.numel() > 1:
|
||||||
|
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
|
||||||
|
return scale.reshape(-1)
|
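For block-quantized checkpoints (weight_block_size set), Fp8Linear.forward instead quantizes activations per 1x128 group before calling w8a8_block_fp8_matmul, as described in the DeepSeek-V3 comment inside forward(). Below is a rough sketch of that grouping, assuming the hidden size is a multiple of the group size; it mirrors the idea behind per_token_group_quant_fp8 rather than its exact implementation, and group_quant_fp8_sketch is a hypothetical name.

import torch

def group_quant_fp8_sketch(x: torch.Tensor, group_size: int = 128):
    # x: [tokens, hidden]; one scale per (token, 128-channel group).
    finfo = torch.finfo(torch.float8_e4m3fn)
    t, h = x.shape
    assert h % group_size == 0
    groups = x.float().reshape(t, h // group_size, group_size)
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    q = (groups / scales).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # scales has shape [tokens, hidden // group_size]
    return q.reshape(t, h), scales.squeeze(-1)

q, scales = group_quant_fp8_sketch(torch.randn(4, 512))
print(q.dtype, scales.shape)  # torch.float8_e4m3fn torch.Size([4, 4])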
@@ -0,0 +1,357 @@
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
|
|
||||||
|
from .hpu import QuantLinear
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPTQWeight(Weight):
|
||||||
|
qweight: torch.Tensor
|
||||||
|
qzeros: torch.Tensor
|
||||||
|
scales: torch.Tensor
|
||||||
|
g_idx: Optional[torch.Tensor]
|
||||||
|
bits: int
|
||||||
|
groupsize: int
|
||||||
|
use_awq_kernel: bool
|
||||||
|
use_exllama: bool
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.scales.dtype == torch.float:
|
||||||
|
self.scales = self.scales.half()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.qweight.device
|
||||||
|
|
||||||
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
if self.use_awq_kernel:
|
||||||
|
try:
|
||||||
|
from text_generation_server.layers.awq.quantize import WQLinear
|
||||||
|
|
||||||
|
return WQLinear(
|
||||||
|
w_bit=self.bits,
|
||||||
|
group_size=self.groupsize,
|
||||||
|
qweight=self.qweight,
|
||||||
|
qzeros=self.qzeros,
|
||||||
|
scales=self.scales,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return QuantLinear(
|
||||||
|
self.qweight,
|
||||||
|
self.qzeros,
|
||||||
|
self.scales,
|
||||||
|
self.g_idx,
|
||||||
|
bias,
|
||||||
|
self.bits,
|
||||||
|
self.groupsize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQWeightsLoader(WeightsLoader):
|
||||||
|
"""
|
||||||
|
Loader for GPTQ- and AWQ-quantized weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bits: int,
|
||||||
|
desc_act: bool,
|
||||||
|
groupsize: int,
|
||||||
|
quant_method: str,
|
||||||
|
quantize: str,
|
||||||
|
sym: bool,
|
||||||
|
):
|
||||||
|
self.bits = bits
|
||||||
|
self.desc_act = desc_act
|
||||||
|
self.groupsize = groupsize
|
||||||
|
self.quant_method = quant_method
|
||||||
|
self.quantize = quantize
|
||||||
|
self.sym = sym
|
||||||
|
|
||||||
|
def get_weights(self, weights: Weights, prefix: str):
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
use_exllama = True
|
||||||
|
if self.bits != 4:
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
if self.desc_act:
|
||||||
|
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
|
if use_exllama and g_idx is not None:
|
||||||
|
g_idx = g_idx - g_idx[0]
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
qweight = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
scales = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
scales = scales.to(dtype=weights.dtype)
|
||||||
|
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
qzeros = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_awq_kernel=self.quantize == "awq",
|
||||||
|
use_exllama=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
try:
|
||||||
|
qweight = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
scales = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
qzeros = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_awq_kernel=self.quantize == "awq",
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
|
use_exllama = True
|
||||||
|
desc_act = self.desc_act
|
||||||
|
if self.bits != 4:
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
if self.desc_act:
|
||||||
|
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
if weights.process_group.size() > 1:
|
||||||
|
if g_idx is not None:
|
||||||
|
if (
|
||||||
|
not torch.equal(
|
||||||
|
# Remove g_idx[0] to adapt the check with TP>1.
|
||||||
|
(g_idx - g_idx[0]).cpu(),
|
||||||
|
torch.tensor(
|
||||||
|
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and not (g_idx == 0).all()
|
||||||
|
):
|
||||||
|
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||||
|
# it would require reordering the input activations that are split across several GPUs
|
||||||
|
use_exllama = False
|
||||||
|
desc_act = True
|
||||||
|
|
||||||
|
from text_generation_server.layers.gptq import (
|
||||||
|
GPTQWeight,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not desc_act and self.groupsize != -1:
|
||||||
|
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
if g_idx is not None:
|
||||||
|
# qzeros, scales sharded, and g_idx must be adjusted accordingly
|
||||||
|
g_idx = g_idx - g_idx[0]
|
||||||
|
else:
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_awq_kernel=self.quantize == "awq",
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_gptq_params(self, weights: Weights):
|
||||||
|
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
|
||||||
|
self.bits = weights.get_tensor("gptq_bits").item()
|
||||||
|
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||||
|
self.desc_act = False
|
||||||
|
# `server quantize` used asymmetric quantization unconditionally
|
||||||
|
# before the `gptq_sym` setting tensor was added.
|
||||||
|
self.sym = (
|
||||||
|
weights.get_tensor("gptq_sym").item()
|
||||||
|
if weights.has_tensor("gptq_sym")
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
self.quant_method = "gptq"
|
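After an AWQ checkpoint is repacked with fast_awq_to_gptq, the loader above synthesizes a trivial g_idx: with 4-bit packing, each int32 row of qweight holds 32 // bits = 8 input channels, and input channel i simply belongs to group i // groupsize. A small illustration of that construction with toy sizes (not tied to a real checkpoint):

import torch

bits, groupsize = 4, 128
packed_rows = 64                           # qweight.shape[0] for a toy tensor
in_features = packed_rows * (32 // bits)   # 8 unpacked channels per int32 row

g_idx = (torch.arange(in_features) // groupsize).to(dtype=torch.int32)
# Equivalent to the list-comprehension form used in the tensor-parallel g_idx check above.
assert torch.equal(
    g_idx,
    torch.tensor([i // groupsize for i in range(in_features)], dtype=torch.int32),
)
print(in_features, g_idx[:3], g_idx[-3:])  # 512, groups 0..0 at the start, 3..3 at the end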
backends/gaudi/server/text_generation_server/layers/gptq/hpu.py (new file)
@@ -0,0 +1,186 @@
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
|
||||||
|
except Exception as e:
|
||||||
|
hpu_import_exception = e
|
||||||
|
|
||||||
|
def error_raiser_hpu(*args, **kwargs):
|
||||||
|
raise ValueError(
|
||||||
|
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
|
||||||
|
)
|
||||||
|
|
||||||
|
convert_from_uint4 = error_raiser_hpu
|
||||||
|
|
||||||
|
|
||||||
|
def pack_tensor(input, bits=4):
|
||||||
|
normal = input.to(torch.int32)
|
||||||
|
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
|
||||||
|
i = 0
|
||||||
|
col = 0
|
||||||
|
while col < q.shape[1]:
|
||||||
|
for j in range(i, i + (32 // bits)):
|
||||||
|
q[:, col] |= normal[:, j] << (bits * (j - i))
|
||||||
|
i += 32 // bits
|
||||||
|
col += 1
|
||||||
|
q = q.to(torch.int32)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinear(nn.Module):
|
||||||
|
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("qweight", qweight)
|
||||||
|
self.register_buffer("qzeros", qzeros)
|
||||||
|
self.register_buffer("scales", scales)
|
||||||
|
self.register_buffer("g_idx", g_idx)
|
||||||
|
if bias is not None:
|
||||||
|
self.register_buffer("bias", bias)
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
if bits not in [4]:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
self.bits = bits
|
||||||
|
self.maxq = 2**self.bits - 1
|
||||||
|
self.groupsize = groupsize
|
||||||
|
|
||||||
|
self.outfeatures = qweight.shape[1]
|
||||||
|
self.infeatures = qweight.shape[0] * 32 // bits
|
||||||
|
self.wf = torch.tensor(
|
||||||
|
list(range(0, 32, self.bits)), dtype=torch.int32
|
||||||
|
).unsqueeze(0)
|
||||||
|
self._preprocessing()
|
||||||
|
|
||||||
|
def unpack_zeros_from_cuda_old_format(self):
|
||||||
|
zeros = torch.bitwise_right_shift(
|
||||||
|
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
|
||||||
|
self.wf.unsqueeze(0),
|
||||||
|
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||||
|
|
||||||
|
zeros = zeros + 1
|
||||||
|
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
|
||||||
|
self.scales.dtype
|
||||||
|
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
|
||||||
|
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
|
||||||
|
return zeros
|
||||||
|
|
||||||
|
def unpack_weight_from_cuda_old_format(self):
|
||||||
|
weight = torch.bitwise_right_shift(
|
||||||
|
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
|
||||||
|
self.wf.unsqueeze(-1),
|
||||||
|
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||||
|
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
|
||||||
|
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def _preprocessing(self):
|
||||||
|
orig_device = self.qweight.device
|
||||||
|
self.qweight = self.qweight.cpu()
|
||||||
|
weight = self.unpack_weight_from_cuda_old_format()
|
||||||
|
new_qweight = pack_tensor(weight)
|
||||||
|
self.qweight = new_qweight.to(orig_device)
|
||||||
|
# TODO: Support group indexing and remove the check
|
||||||
|
columns = self.qweight.shape[0]
|
||||||
|
g_idx_trivial = [i // self.groupsize for i in range(columns)]
|
||||||
|
g_idx_trivial = torch.tensor(
|
||||||
|
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
self.g_idx, g_idx_trivial
|
||||||
|
), "Non-trivial tensor g_idx is not supported"
|
||||||
|
self.qzeros = self.qzeros.cpu()
|
||||||
|
zeros = self.unpack_zeros_from_cuda_old_format()
|
||||||
|
new_qzeros = pack_tensor(zeros)
|
||||||
|
self.qzeros = new_qzeros.to(orig_device)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||||
|
if bits not in [4]:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||||
|
qzeros = torch.zeros(
|
||||||
|
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
scales = torch.zeros(
|
||||||
|
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||||
|
)
|
||||||
|
g_idx = torch.tensor(
|
||||||
|
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||||
|
|
||||||
|
def pack(self, linear, scales, zeros, g_idx=None):
|
||||||
|
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||||
|
|
||||||
|
scales = scales.t().contiguous()
|
||||||
|
zeros = zeros.t().contiguous()
|
||||||
|
scale_zeros = zeros * scales
|
||||||
|
self.scales = scales.clone().half()
|
||||||
|
if linear.bias is not None:
|
||||||
|
self.bias = linear.bias.clone().half()
|
||||||
|
|
||||||
|
intweight = []
|
||||||
|
for idx in range(self.infeatures):
|
||||||
|
intweight.append(
|
||||||
|
torch.round(
|
||||||
|
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||||
|
/ self.scales[self.g_idx[idx]]
|
||||||
|
).to(torch.int)[:, None]
|
||||||
|
)
|
||||||
|
intweight = torch.cat(intweight, dim=1)
|
||||||
|
intweight = intweight.t().contiguous()
|
||||||
|
intweight = intweight.numpy().astype(np.uint32)
|
||||||
|
qweight = np.zeros(
|
||||||
|
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
row = 0
|
||||||
|
while row < qweight.shape[0]:
|
||||||
|
if self.bits in [4]:
|
||||||
|
for j in range(i, i + (32 // self.bits)):
|
||||||
|
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||||
|
i += 32 // self.bits
|
||||||
|
row += 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qweight = qweight.astype(np.int32)
|
||||||
|
self.qweight = torch.from_numpy(qweight)
|
||||||
|
|
||||||
|
zeros -= 1
|
||||||
|
zeros = zeros.numpy().astype(np.uint32)
|
||||||
|
qzeros = np.zeros(
|
||||||
|
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
col = 0
|
||||||
|
while col < qzeros.shape[1]:
|
||||||
|
if self.bits in [4]:
|
||||||
|
for j in range(i, i + (32 // self.bits)):
|
||||||
|
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||||
|
i += 32 // self.bits
|
||||||
|
col += 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qzeros = qzeros.astype(np.int32)
|
||||||
|
self.qzeros = torch.from_numpy(qzeros)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||||
|
x = x.reshape(-1, x.shape[-1])
|
||||||
|
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
|
||||||
|
out = torch.matmul(x, weight)
|
||||||
|
out = out.reshape(out_shape)
|
||||||
|
out = out + self.bias if self.bias is not None else out
|
||||||
|
return out
|
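A quick round-trip check of the 4-bit packing scheme used by pack_tensor above: eight values in [0, 15] share one int32, each shifted left by 4*j. This is a standalone sketch with a toy tensor, not a test from the repository; unpack_tensor is a hypothetical helper written for the illustration, while pack_tensor is the function defined above.

import torch

def unpack_tensor(q: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # Inverse of pack_tensor: expand each int32 column back into 32 // bits values.
    shifts = torch.arange(0, 32, bits, dtype=torch.int32)      # [0, 4, ..., 28]
    out = torch.bitwise_right_shift(q.unsqueeze(-1), shifts)   # [rows, cols, 32 // bits]
    out = torch.bitwise_and(out, (1 << bits) - 1)
    return out.reshape(q.shape[0], -1).to(torch.int32)

vals = torch.randint(0, 16, (2, 32), dtype=torch.int32)  # 32 columns -> 4 packed columns
packed = pack_tensor(vals)
assert packed.shape == (2, 4)
assert torch.equal(unpack_tensor(packed), vals)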
backends/gaudi/server/text_generation_server/layers/gptq/quantize.py (new file, 1026 lines; diff suppressed because it is too large)
@@ -0,0 +1,56 @@
import torch


# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
    y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """
    Compute SNR between y_pred(tensor) and y_real(tensor)

    SNR can be calculated with the following equation:

        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

    If x and y are matrices, the SNR error over the matrix is the mean value of the
    SNR error over all elements:

        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): predicted (e.g. quantized) tensor
        y_real (torch.Tensor): reference tensor
        reduction (str, optional): "mean", "sum" or "none". Defaults to 'mean'.

    Raises:
        ValueError: if the two tensors have different shapes
        ValueError: if the reduction method is unsupported

    Returns:
        torch.Tensor: the SNR error
    """
    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. "
            f"({y_pred.shape} and {y_real.shape})"
        )
    reduction = str(reduction).lower()

    if y_pred.ndim == 1:
        y_pred = y_pred.unsqueeze(0)
        y_real = y_real.unsqueeze(0)

    y_pred = y_pred.flatten(start_dim=1)
    y_real = y_real.flatten(start_dim=1)

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    snr = (noise_power) / (signal_power + 1e-7)

    if reduction == "mean":
        return torch.mean(snr)
    elif reduction == "sum":
        return torch.sum(snr)
    elif reduction == "none":
        return snr
    else:
        raise ValueError("Unsupported reduction method.")
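A quick usage example for torch_snr_error, e.g. to measure how much quantization perturbed a weight; the numbers are made up for illustration.

import torch

w = torch.randn(128, 256)
w_noisy = w + 0.01 * torch.randn_like(w)  # stand-in for a dequantized weight

print(torch_snr_error(w_noisy, w).item())                    # small value, around 1e-4
print(torch_snr_error(w_noisy, w, reduction="none").shape)   # torch.Size([128])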
@@ -0,0 +1,67 @@
import torch
from torch import nn
from accelerate import init_empty_weights


# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = torch.nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias


class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states += residual
        residual = hidden_states

        return super().forward(hidden_states), residual


class FastRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()

        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        weight = weights.get_tensor(f"{prefix}.weight")
        return cls(weight, eps)

    def forward(self, hidden_states, residual=None):
        from vllm_hpu_extension.kernels import rms_norm

        orig_shape = hidden_states.shape
        if residual is not None:
            residual += hidden_states.view(residual.shape)
        else:
            residual = hidden_states
        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
        if len(orig_shape) == 2:
            residual = residual.unsqueeze(0)
        x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
        return x.view(orig_shape), residual.view(orig_shape)
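FastLayerNorm keeps the residual add fused with the normalization and always returns a (normalized, residual) pair, so callers can thread the running residual through a transformer block. A minimal CPU example (FastRMSNorm is omitted here because its forward requires the vllm_hpu_extension kernels):

import torch

ln = FastLayerNorm(64, eps=1e-5)
hidden = torch.randn(8, 64)
residual = torch.randn(8, 64)

normed, new_residual = ln(hidden, residual)
# `hidden` was updated in place with the residual add and returned as the new residual.
assert torch.equal(new_residual, hidden)
print(normed.shape)  # torch.Size([8, 64])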
@@ -0,0 +1,38 @@
import torch
from torch.nn import functional as F


class FastLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


def get_linear(weight, bias):
    # Weights that are loaded through methods that are not
    # quantization-aware are still bare tensors. We may want
    # to change this in the future.
    if isinstance(weight, torch.Tensor):
        return FastLinear(weight, bias)

    return weight.get_linear(bias)
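get_linear is the dispatch point between unquantized and quantized weights: a bare tensor becomes a FastLinear, anything else is asked for its own get_linear(bias) (for example the GPTQWeight or Fp8Weight classes above). A small sketch of the unquantized path:

import torch

weight = torch.randn(32, 16)  # [out_features, in_features]
bias = torch.zeros(32)

layer = get_linear(weight, bias)        # plain tensor -> FastLinear
out = layer(torch.randn(4, 16))
print(type(layer).__name__, out.shape)  # FastLinear torch.Size([4, 32])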
backends/gaudi/server/text_generation_server/layers/lora.py (new file)
@@ -0,0 +1,279 @@
from typing import TYPE_CHECKING, Optional, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from torch import nn
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from text_generation_server.utils.sgmv import (
|
||||||
|
add_lora_a_bgmv,
|
||||||
|
add_lora_b_bgmv,
|
||||||
|
has_sgmv,
|
||||||
|
lora_a_sgmv_cutlass,
|
||||||
|
lora_b_sgmv_cutlass,
|
||||||
|
orient_for_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from text_generation_server.adapters import AdapterBatchData
|
||||||
|
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLinear(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.base_layer = base_layer
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.process_group = process_group
|
||||||
|
|
||||||
|
def forward_layer_type(
|
||||||
|
self,
|
||||||
|
result: torch.Tensor,
|
||||||
|
input: torch.Tensor,
|
||||||
|
adapter_data: "AdapterBatchData",
|
||||||
|
layer_type: str,
|
||||||
|
start_idx: int,
|
||||||
|
end_idx: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if adapter_data is None:
|
||||||
|
return result
|
||||||
|
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
|
||||||
|
|
||||||
|
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||||
|
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||||
|
# The 'result' tensor represents the full output, which can vary in size based on
|
||||||
|
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||||
|
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
|
||||||
|
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
|
||||||
|
# This approach ensures accurate LoRA application across various layer sizes and
|
||||||
|
# configurations, adapting to different model architectures and parallelization strategies.
|
||||||
|
#
|
||||||
|
# Example scenarios where this is necessary:
|
||||||
|
# 1. The adapter's size doesn't evenly divide across GPUs.
|
||||||
|
# 2. We're processing the last segment which might be smaller.
|
||||||
|
# 3. Different projection layers (q, k, v) have different sizes.
|
||||||
|
if end_idx - start_idx != result.shape[1]:
|
||||||
|
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||||
|
else:
|
||||||
|
proj = result
|
||||||
|
|
||||||
|
for r, rank_segments in data.rank_data.items():
|
||||||
|
lora_a_ptr = rank_segments.lora_a_ptr
|
||||||
|
lora_b_ptr = rank_segments.lora_b_ptr
|
||||||
|
|
||||||
|
if lora_a_ptr is None or lora_b_ptr is None:
|
||||||
|
raise ValueError("LoRA data is missing")
|
||||||
|
|
||||||
|
if data.use_sgmv:
|
||||||
|
# Use SGMV for prefill
|
||||||
|
v = lora_a_sgmv_cutlass(
|
||||||
|
input,
|
||||||
|
rank_segments.tmp_shrink,
|
||||||
|
lora_a_ptr,
|
||||||
|
rank_segments.segment_starts,
|
||||||
|
rank_segments.segment_ends,
|
||||||
|
self.layer_id,
|
||||||
|
r,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
v = self.collect_lora_a(v)
|
||||||
|
|
||||||
|
lora_b_sgmv_cutlass(
|
||||||
|
proj,
|
||||||
|
v,
|
||||||
|
rank_segments.tmp_expand,
|
||||||
|
lora_b_ptr,
|
||||||
|
rank_segments.segment_starts,
|
||||||
|
rank_segments.segment_ends,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use BGMV for decode
|
||||||
|
v = torch.zeros(
|
||||||
|
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||||
|
)
|
||||||
|
# TODO: error with [-1, 0], but not [0, -1]
|
||||||
|
add_lora_a_bgmv(
|
||||||
|
v,
|
||||||
|
input,
|
||||||
|
lora_a_ptr,
|
||||||
|
rank_segments.indices,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
v = self.collect_lora_a(v)
|
||||||
|
|
||||||
|
add_lora_b_bgmv(
|
||||||
|
proj,
|
||||||
|
v,
|
||||||
|
lora_b_ptr,
|
||||||
|
rank_segments.indices,
|
||||||
|
self.layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if end_idx - start_idx != result.shape[1]:
|
||||||
|
result[:, start_idx:end_idx] += proj
|
||||||
|
else:
|
||||||
|
for adapter_index in adapter_data.meta.adapter_set:
|
||||||
|
if data is not None and data.has_adapter(adapter_index):
|
||||||
|
adapter_mask = (
|
||||||
|
(adapter_data.meta.adapter_indices == adapter_index)
|
||||||
|
.to(input.dtype)
|
||||||
|
.view(-1, 1)
|
||||||
|
)
|
||||||
|
layer_result = self.forward_lora(
|
||||||
|
input, data, adapter_index, adapter_mask
|
||||||
|
)
|
||||||
|
result[:, start_idx:end_idx] += layer_result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def forward_lora(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
data: "BatchLoraWeights",
|
||||||
|
adapter_index: int,
|
||||||
|
adapter_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
|
||||||
|
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
|
||||||
|
|
||||||
|
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||||
|
|
||||||
|
a_out = input @ lora_a
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
a_out = self.collect_lora_a(a_out)
|
||||||
|
|
||||||
|
result = (a_out @ lora_b) * adapter_mask
|
||||||
|
return result
|
||||||
|
|
||||||
|
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
raise NotImplementedError("Implemented in subclasses")
|
||||||
|
|
||||||
|
|
||||||
|
class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_layer: nn.Module,
|
||||||
|
layer_id: int,
|
||||||
|
layer_names: List[str],
|
||||||
|
sizes: List[int],
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
):
|
||||||
|
super().__init__(base_layer, layer_id, process_group)
|
||||||
|
self.layer_names = layer_names
|
||||||
|
self.sizes = sizes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
base_layer: nn.Module,
|
||||||
|
layer_id: int,
|
||||||
|
layer_names: List[str],
|
||||||
|
sizes: List[int],
|
||||||
|
process_group: ProcessGroup,
|
||||||
|
):
|
||||||
|
return TensorParallelMultiAdapterLinear(
|
||||||
|
base_layer, layer_id, layer_names, sizes, process_group
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||||
|
) -> torch.Tensor:
|
||||||
|
result = self.base_layer(input)
|
||||||
|
|
||||||
|
# noop if no layer names are provided (e.g. for models without adapters)
|
||||||
|
if self.layer_names is None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# handle models like Bloom that have inputs of shape
|
||||||
|
# (batch_size, sequence_length, hidden_size)
|
||||||
|
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||||
|
# for the LoRA computation, then reshape back
|
||||||
|
prev_shape = result.shape
|
||||||
|
is_3d = len(input.shape) >= 3
|
||||||
|
if is_3d:
|
||||||
|
input = input.reshape(-1, input.shape[-1])
|
||||||
|
result = result.reshape(-1, result.shape[-1])
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for i, layer_name in enumerate(self.layer_names):
|
||||||
|
start_idx = offset // self.process_group.size()
|
||||||
|
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
|
||||||
|
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
|
||||||
|
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
|
||||||
|
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
|
||||||
|
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
|
||||||
|
# different projection sizes across layers and model architectures.
|
||||||
|
if self.sizes is not None:
|
||||||
|
offset += self.sizes[i]
|
||||||
|
end_idx = offset // self.process_group.size()
|
||||||
|
else:
|
||||||
|
end_idx = result.shape[1]
|
||||||
|
|
||||||
|
result = self.forward_layer_type(
|
||||||
|
result, input, adapter_data, layer_name, start_idx, end_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_3d:
|
||||||
|
result = result.reshape(prev_shape)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
|
||||||
|
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
|
||||||
|
#
|
||||||
|
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
|
||||||
|
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||||
|
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||||
|
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||||
|
gathered_tensors = [
|
||||||
|
torch.empty_like(a_out) for _ in range(self.process_group.size())
|
||||||
|
]
|
||||||
|
torch.distributed.all_gather(gathered_tensors, a_out)
|
||||||
|
return torch.cat(gathered_tensors, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorParallelAdapterRowLinear(LoraLinear):
|
||||||
|
def __init__(self, base_layer, layer_id, layer_name, process_group):
|
||||||
|
super().__init__(base_layer, layer_id, process_group)
|
||||||
|
self.layer_name = layer_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, base_layer, layer_id, layer_name, process_group):
|
||||||
|
return cls(base_layer, layer_id, layer_name, process_group)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||||
|
) -> torch.Tensor:
|
||||||
|
result = self.base_layer(input)
|
||||||
|
|
||||||
|
if self.layer_name is None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
|
||||||
|
stride = result.shape[-1] // self.process_group.size()
|
||||||
|
start_idx = self.process_group.rank() * stride
|
||||||
|
end_idx = (self.process_group.rank() + 1) * stride
|
||||||
|
|
||||||
|
self.forward_layer_type(
|
||||||
|
result, input, adapter_data, self.layer_name, start_idx, end_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
|
||||||
|
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
|
||||||
|
#
|
||||||
|
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
|
||||||
|
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||||
|
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||||
|
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||||
|
torch.distributed.all_reduce(a_out, group=self.process_group)
|
||||||
|
return a_out
|
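Stripped of the SGMV/BGMV kernels and the tensor-parallel gather/reduce, the per-adapter fallback in forward_layer_type and forward_lora computes result += ((x @ A) @ B) * mask, where the mask selects the rows of the batch that belong to that adapter. An illustrative sketch with toy shapes (the names below are hypothetical and not the repository's API):

import torch

hidden, rank, out_dim, batch = 64, 8, 64, 6
x = torch.randn(batch, hidden)
base = torch.nn.Linear(hidden, out_dim, bias=False)

lora_a = torch.randn(hidden, rank) * 0.01   # down-projection
lora_b = torch.randn(rank, out_dim) * 0.01  # up-projection

adapter_indices = torch.tensor([0, 0, 1, 0, 1, 1])    # which adapter each row uses
mask = (adapter_indices == 0).to(x.dtype).view(-1, 1)  # rows served by adapter 0

result = base(x)
result = result + ((x @ lora_a) @ lora_b) * mask
print(result.shape)  # torch.Size([6, 64])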
backends/gaudi/server/text_generation_server/layers/medusa.py (new file)
@@ -0,0 +1,191 @@
import torch
|
||||||
|
from torch import nn
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
|
from text_generation_server.layers.linear import FastLinear
|
||||||
|
from text_generation_server.layers.tensor_parallel import (
|
||||||
|
TensorParallelHead,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(torch.nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = FastLinear.load(
|
||||||
|
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaModel(torch.nn.Module):
|
||||||
|
def __init__(self, config, medusa_config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
|
||||||
|
for i in range(get_speculate())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.heads:
|
||||||
|
return None
|
||||||
|
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
||||||
|
return speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaHead(torch.nn.Module):
|
||||||
|
def __init__(self, config, medusa_config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||||
|
for i in range(medusa_config["medusa_num_layers"])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
n = len(self.blocks)
|
||||||
|
self.out = FastLinear.load(
|
||||||
|
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
x = self.out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaHeadV1(nn.Module):
|
||||||
|
def __init__(self, lm_head, medusa):
|
||||||
|
super().__init__()
|
||||||
|
self.lm_head = lm_head
|
||||||
|
self.medusa = medusa
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(config, prefix: str, weights):
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors import safe_open
|
||||||
|
import json
|
||||||
|
|
||||||
|
speculator = config.speculator
|
||||||
|
|
||||||
|
path = speculator["path"]
|
||||||
|
medusa_config = str(Path(path) / "config.json")
|
||||||
|
|
||||||
|
for fname in speculator["model_paths"]:
|
||||||
|
filename = str(Path(path) / fname)
|
||||||
|
|
||||||
|
with open(medusa_config, "r") as f:
|
||||||
|
medusa_config = json.load(f)
|
||||||
|
routing = weights.routing
|
||||||
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k in routing and routing[k] != filename:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
|
medusa = MedusaModel(config, medusa_config, weights)
|
||||||
|
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
return MedusaHeadV1(lm_head, medusa)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
logits = self.lm_head(input)
|
||||||
|
# If we have too many tokens, we skip speculative logits
|
||||||
|
if input.shape[0] > 128:
|
||||||
|
return logits, None
|
||||||
|
|
||||||
|
speculative_logits = self.medusa(input)
|
||||||
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaHeadV2(nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors import safe_open
|
||||||
|
import json
|
||||||
|
|
||||||
|
speculator_path = config.speculator["path"]
|
||||||
|
|
||||||
|
medusa_config = str(Path(speculator_path) / "config.json")
|
||||||
|
filename = str(Path(speculator_path) / "medusa_lm_head.safetensors")
|
||||||
|
|
||||||
|
with open(medusa_config, "r") as f:
|
||||||
|
medusa_config = json.load(f)
|
||||||
|
routing = weights.routing
|
||||||
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k in routing and routing[k] != filename:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
|
self.n_medusa_heads = get_speculate()
|
||||||
|
|
||||||
|
assert medusa_config["medusa_num_layers"] == 1
|
||||||
|
self.linear = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
self.world_size = self.process_group.size()
|
||||||
|
self.rank = self.process_group.rank()
|
||||||
|
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
|
||||||
|
self.lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# If we have too many tokens, we skip speculative logits
|
||||||
|
if x.shape[0] > 128:
|
||||||
|
logits = self.lm_head(x)
|
||||||
|
return logits, None
|
||||||
|
|
||||||
|
size = x.shape[-1]
|
||||||
|
block_size = (size + self.world_size - 1) // self.world_size
|
||||||
|
start = self.rank * block_size
|
||||||
|
stop = (self.rank + 1) * block_size
|
||||||
|
|
||||||
|
x_block = x[:, start:stop]
|
||||||
|
|
||||||
|
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
|
||||||
|
medusa_res = self.act(self.linear(x)).reshape(
|
||||||
|
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply all residual medusa heads
|
||||||
|
output = x[:, start:stop].unsqueeze(-2) + medusa_res
|
||||||
|
|
||||||
|
# Gather medusa heads
|
||||||
|
world_output = [
|
||||||
|
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||||
|
]
|
||||||
|
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||||
|
world_output = torch.cat(world_output, dim=-1)
|
||||||
|
|
||||||
|
# Stack x and medusa residual x
|
||||||
|
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
|
||||||
|
|
||||||
|
# Compute lm head on x + medusa residual x
|
||||||
|
logits = self.lm_head(stacked_x)
|
||||||
|
|
||||||
|
# Finally, split logits from speculative logits
|
||||||
|
logits, speculative_logits = torch.split(
|
||||||
|
logits, [1, self.n_medusa_heads], dim=-2
|
||||||
|
)
|
||||||
|
# Squeeze added dimension
|
||||||
|
logits = logits.squeeze(-2)
|
||||||
|
|
||||||
|
return logits, speculative_logits
|
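Shape-wise, MedusaModel stacks one set of logits per speculative head along dim=1, so MedusaHeadV1 returns regular logits of shape (tokens, vocab) plus speculative logits of shape (tokens, n_heads, vocab), or None when more than 128 tokens are processed at once. A toy illustration of the stacking with stand-in tensors instead of real heads:

import torch

tokens, vocab, n_heads = 4, 32000, 3
per_head_logits = [torch.randn(tokens, vocab) for _ in range(n_heads)]

speculative_logits = torch.stack(per_head_logits, dim=1)
print(speculative_logits.shape)  # torch.Size([4, 3, 32000])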
backends/gaudi/server/text_generation_server/layers/mlp.py (new file)
@@ -0,0 +1,282 @@
import torch
|
||||||
|
import math
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
|
||||||
|
from text_generation_server.layers.tensor_parallel import TensorParallelHead
|
||||||
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorLayerNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
An L2 normalization implementation
|
||||||
|
...
|
||||||
|
Args
|
||||||
|
----
|
||||||
|
normalized_shape : int
|
||||||
|
Dimensionality of input data (size of final tensor axis)
|
||||||
|
elementwise_scale_weight : torch.Tensor
|
||||||
|
learned scaling term after normalization?
|
||||||
|
elementwise_shift_bias : torch.Tensor
|
||||||
|
learned bias term after normalization?
|
||||||
|
eps : float
|
||||||
|
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
eps=1e-06,
|
||||||
|
):
|
||||||
|
super(MLPSpeculatorLayerNorm, self).__init__()
|
||||||
|
self.weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
self.bias = weights.get_tensor(f"{prefix}.bias")
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xf = x
|
||||||
|
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
x = xf.type_as(x)
|
||||||
|
x = self.weight * x
|
||||||
|
x = x + self.bias
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
INV_SQRT2 = 2**-0.5
|
||||||
|
|
||||||
|
|
||||||
|
def simple_norm(x: torch.Tensor, eps=1e-06):
|
||||||
|
xf = x
|
||||||
|
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
|
||||||
|
x = xf.type_as(x)
|
||||||
|
return x * INV_SQRT2
|
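Both MLPSpeculatorLayerNorm and simple_norm are RMS-style normalizations: divide by the root mean square of the last axis (plus eps), then apply either the learned scale/shift or the 1/sqrt(2) factor. A tiny numeric check of that equivalence, using the simple_norm and INV_SQRT2 defined above:

import torch

x = torch.randn(2, 16)
eps = 1e-6

rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
assert torch.allclose(simple_norm(x, eps=eps), x * rms * INV_SQRT2)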
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorModelTied(torch.nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.n_predict = get_speculate()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
|
||||||
|
self.proj0 = FastLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.proj.0",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.proj1 = FastLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.proj.1",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
|
||||||
|
self.ln = MLPSpeculatorLayerNorm(
|
||||||
|
prefix=f"{prefix}.ln.0",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
||||||
|
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
|
||||||
|
self.activation = nn.GELU()
|
||||||
|
self.vsize = config.vocab_size
|
||||||
|
self.inner_dim = config.speculator_config["inner_dim"]
|
||||||
|
self.top_k_tokens_per_head = [1] * self.n_predict
|
||||||
|
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
|
||||||
|
self.inner_dim / 2
|
||||||
|
)
|
||||||
|
self.emb.weight *= self.emb_weight
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
top_k_tokens_per_head = self.top_k_tokens_per_head
|
||||||
|
|
||||||
|
# k indicates # of candidates
|
||||||
|
# h indicates # of generated tokens
|
||||||
|
state = hidden_states
|
||||||
|
b = state.size(0)
|
||||||
|
ind = input_ids.unsqueeze(0)
|
||||||
|
all_probs = torch.empty(
|
||||||
|
b, self.n_predict, self.vsize, device=state.device
|
||||||
|
) # b k h v
|
||||||
|
assert (
|
||||||
|
len(top_k_tokens_per_head) == self.n_predict
|
||||||
|
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
|
||||||
|
for i in range(self.n_predict):
|
||||||
|
# Project and predict
|
||||||
|
z = self.emb(ind)
|
||||||
|
# z = z.mul(self.emb_weight) # b k d
|
||||||
|
if i == 0:
|
||||||
|
state = self.proj0(state) * self.state_weight + z
|
||||||
|
else:
|
||||||
|
state = self.proj1(state) * self.state_weight + z
|
||||||
|
state = self.activation(self.ln(state)) # b k d
|
||||||
|
probs = F.log_softmax(self.head(state), dim=-1) # b k v
|
||||||
|
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
||||||
|
|
||||||
|
# Update candidate set with new predictions
|
||||||
|
|
||||||
|
# Update distribution set with new logits
|
||||||
|
all_probs[:, i] = probs.exp()
|
||||||
|
|
||||||
|
# Update state, log_probs and ind for new predictions
|
||||||
|
state = state.unsqueeze(2).expand(
|
||||||
|
-1, -1, top_k_tokens_per_head[i], -1
|
||||||
|
) # b k k' d
|
||||||
|
state = state.reshape(-1, b, state.size(3)) # b kk' d
|
||||||
|
ind = preds.view(-1, b) # b kk'
|
||||||
|
|
||||||
|
speculative_logits = all_probs
|
||||||
|
return speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorModel(torch.nn.Module):
|
||||||
|
def __init__(self, config, prefix, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.n_predict = get_speculate()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.emb = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.proj = [
|
||||||
|
FastLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.proj.{i}",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
self.head = nn.ModuleList(
|
||||||
|
[
|
||||||
|
FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.ln = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MLPSpeculatorLayerNorm(
|
||||||
|
prefix=f"{prefix}.ln.{i}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
||||||
|
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
|
||||||
|
self.activation = nn.GELU()
|
||||||
|
self.vsize = config.vocab_size
|
||||||
|
self.inner_dim = config.speculator_config["inner_dim"]
|
||||||
|
self.top_k_tokens_per_head = [1] * self.n_predict
|
||||||
|
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
|
||||||
|
self.inner_dim / 2
|
||||||
|
)
|
||||||
|
self.emb.weight *= self.emb_weight
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
top_k_tokens_per_head = self.top_k_tokens_per_head
|
||||||
|
|
||||||
|
# k indicates # of candidates
|
||||||
|
# h indicates # of generated tokens
|
||||||
|
state = hidden_states
|
||||||
|
b = state.size(0)
|
||||||
|
ind = input_ids.unsqueeze(0)
|
||||||
|
all_probs = torch.empty(
|
||||||
|
b, self.n_predict, self.vsize, device=state.device
|
||||||
|
) # b k h v
|
||||||
|
assert (
|
||||||
|
len(top_k_tokens_per_head) == self.n_predict
|
||||||
|
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
|
||||||
|
for i in range(self.n_predict):
|
||||||
|
# Project and predict
|
||||||
|
z = self.emb[i](ind)
|
||||||
|
# z = z.mul(self.emb_weight) # b k d
|
||||||
|
state = self.proj[i](state) * self.state_weight + z
|
||||||
|
state = self.activation(self.ln[i](state)) # b k d
|
||||||
|
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
|
||||||
|
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
||||||
|
|
||||||
|
# Update candidate set with new predictions
|
||||||
|
|
||||||
|
# Update distribution set with new logits
|
||||||
|
all_probs[:, i] = probs.exp()
|
||||||
|
|
||||||
|
# Update state, log_probs and ind for new predictions
|
||||||
|
state = state.unsqueeze(2).expand(
|
||||||
|
-1, -1, top_k_tokens_per_head[i], -1
|
||||||
|
) # b k k' d
|
||||||
|
state = state.reshape(-1, b, state.size(3)) # b kk' d
|
||||||
|
ind = preds.view(-1, b) # b kk'
|
||||||
|
|
||||||
|
speculative_logits = all_probs
|
||||||
|
return speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorHead(nn.Module):
|
||||||
|
def __init__(self, lm_head, mlp_speculator, scale_input: bool):
|
||||||
|
super().__init__()
|
||||||
|
self.lm_head = lm_head
|
||||||
|
self.mlp_speculator = mlp_speculator
|
||||||
|
self.scale_input = scale_input
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
logits = self.lm_head(input)
|
||||||
|
# If we have too many tokens, we skip speculative logits
|
||||||
|
if input.shape[0] > 128:
|
||||||
|
return logits, None
|
||||||
|
|
||||||
|
input_ids = logits.argmax(dim=-1)
|
||||||
|
if self.scale_input:
|
||||||
|
input = simple_norm(input)
|
||||||
|
speculative_logits = self.mlp_speculator(input, input_ids)
|
||||||
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(config, prefix: str, weights):
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
speculator_path = config.speculator["path"]
|
||||||
|
|
||||||
|
for fname in config.speculator["model_paths"]:
|
||||||
|
filename = str(Path(speculator_path) / fname)
|
||||||
|
routing = weights.routing
|
||||||
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k in routing and routing[k] != filename:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
|
tie_weights = config.speculator_config.get("tie_weights", False)
|
||||||
|
if tie_weights:
|
||||||
|
mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
|
||||||
|
else:
|
||||||
|
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
|
||||||
|
# This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
|
||||||
|
scale_input = config.speculator_config.get("scale_input", False)
|
||||||
|
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
|
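For orientation, a minimal standalone sketch (plain PyTorch, not part of this diff; all shapes are made up) of what the normalization above computes: MLPSpeculatorLayerNorm and simple_norm are RMS-style norms, x * rsqrt(mean(x^2) + eps), with simple_norm adding an extra 2**-0.5 factor.

# Standalone sketch of the RMS-style normalization used above.
import torch

eps = 1e-6
x = torch.randn(4, 8, dtype=torch.float32)

rms_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
# After the norm, every row has approximately unit root-mean-square.
print(rms_normed.pow(2).mean(-1))  # ~1.0 per row

# simple_norm additionally multiplies by 2**-0.5, halving the expected
# squared magnitude; this mirrors the 50/50 state/embedding weighting
# used by the speculator models above.
simple = rms_normed * 2**-0.5
print(simple.pow(2).mean(-1))  # ~0.5 per row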
@@ -0,0 +1,250 @@
from typing import Optional, Protocol, runtime_checkable

import torch
import torch.nn as nn
from loguru import logger
from transformers.activations import ACT2FN

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    Weights,
    UnquantizedWeight,
)

from .fused_moe import fused_topk, grouped_topk

# NOTE: we are using a protocol here, because multiple inheritance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
# class inheritance is whacky.


@runtime_checkable
class MoELayer(Protocol):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ): ...

    def forward(
        self, x: torch.Tensor, *, gating_output: torch.Tensor
    ) -> torch.Tensor: ...


class DenseMoELayer(nn.Module):
    """
    Layer for MoE that applies *all* experts to each token and then weights
    their outputs based on the calculated routing. This layer is much slower
    than `SparseMoELayer` and should only be used when no fused kernels are
    available (e.g. for unsupported quantizers).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        super().__init__()

        assert scoring_func is None, "scoring func is not handled"
        assert e_score_correction_bias is None, "scoring correction bias is not handled"

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        self.gate_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{gate_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.up_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{up_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.down_proj = [
            TensorParallelRowLinear.load(
                None,
                prefix=f"{prefix}.{i}.{down_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        x: (sequence_length, model_dim)
        gating_output: (sequence_length, n_experts)
        """
        # optional reshape
        input_shape = x.shape
        x = x.view(-1, input_shape[-1])

        if self.n_expert_group is not None and self.topk_group is not None:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )
        topk_weights = topk_weights.to(x.dtype)

        weights = torch.zeros(
            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
        )

        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))

        out = torch.zeros_like(x)
        for i in range(self.n_experts):
            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
            h = self.down_proj[i](h, reduce=False)
            out += h * weights[:, i].view(-1, 1)

        return out


class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()
        if (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
            if (
                isinstance(weights.loader, HybridFP8UnquantLoader)
                and weights.loader.to_fp8
            ):
                cls = FP8SparseMoELayer
            else:
                cls = UnquantizedSparseMoELayer
        else:
            raise ValueError(
                f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
            )

        log_once(
            logger.info,
            "Using MoE layer with fused gemm",
        )

        self.moe = cls(
            n_expert_group=n_expert_group,
            n_experts=n_experts,
            prefix=prefix,
            renormalize=renormalize,
            topk=topk,
            topk_group=topk_group,
            weights=weights,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return self.moe(x, gating_output=gating_output)

    @staticmethod
    def is_supported(weights: Weights) -> bool:
        return (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader)
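A toy sketch of the routing and combination logic in DenseMoELayer.forward, with plain nn.Linear experts standing in for the tensor-parallel gate/up/down projections; this is an illustration of the weighting scheme only, not the TGI layer API, and every shape below is invented.

# Toy dense-MoE combination: run every expert, weight its output per token.
import torch
import torch.nn as nn

n_experts, topk, d_model, seq = 4, 2, 16, 5
experts = [nn.Linear(d_model, d_model, bias=False) for _ in range(n_experts)]

x = torch.randn(seq, d_model)
gating_output = torch.randn(seq, n_experts)  # router logits, one per expert

# Top-k routing weights, renormalized over the selected experts.
probs = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# Scatter back to a dense (seq, n_experts) matrix; unused experts get weight 0.
weights = torch.zeros(seq, n_experts).scatter_(1, topk_ids, topk_weights)

out = torch.zeros_like(x)
for i, expert in enumerate(experts):
    out += expert(x) * weights[:, i].view(-1, 1)
print(out.shape)  # (5, 16)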
backends/gaudi/server/text_generation_server/layers/moe/fp8.py
@@ -0,0 +1,173 @@
from typing import Optional

import torch
import torch.nn as nn

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
    Fp8Weight,
    fp8_quantize,
    quant_dtype,
    normalize_e4m3fn_to_native_float8,
)

try:
    from .unquantized import fused_moe
except Exception:
    fused_moe = None


class FP8SparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        (
            self.gate_up_proj,
            self.gate_up_proj_weight_scale,
            self.gate_up_proj_input_scale,
        ) = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
            _load_expert_weights_row(
                prefix=prefix,
                n_experts=n_experts,
                name=down_proj_name,
                weights=weights,
            )
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            use_fp8_w8a8=True,
            w1_scale=self.gate_up_proj_weight_scale,
            w2_scale=self.down_proj_weight_scale,
            a1_scale=self.gate_up_proj_input_scale,
            a2_scale=self.down_proj_input_scale,
        )


def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i, name, weights)

        assert isinstance(weight, Fp8Weight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=quant_dtype,
                device=weight.weight.device,
            )
        if all_weight_scales is None:
            all_weight_scales = torch.empty(
                (n_experts,) + weight.weight_scale.shape,
                dtype=torch.float32,
                device=weight.weight.device,
            )

        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            all_weight[i], all_weight_scales[i], current_input_scale = (
                normalize_e4m3fn_to_native_float8(
                    weight.weight, weight.weight_scale, weight.input_scale
                )
            )
            if current_input_scale is not None:
                if max_input_scale is None or current_input_scale > max_input_scale:
                    max_input_scale = current_input_scale
        else:
            all_weight[i], all_weight_scales[i] = fp8_quantize(
                weight.weight, scalar=True
            )

    assert all_weight is not None

    return all_weight, all_weight_scales, max_input_scale


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
    )


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
    )
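A minimal sketch of the bookkeeping that _load_expert_weights performs, assuming plain PyTorch >= 2.1 with float8 dtypes: per-expert weights are quantized with a scalar scale, stacked into one (n_experts, ...) tensor, and a single worst-case input scale is kept. TGI's Fp8Weight/fp8_quantize helpers are replaced here by a hand-rolled per-tensor quantization, and all shapes are invented.

# Stack per-expert FP8 weights and track the max input scale (illustration only).
import torch

n_experts, out_dim, in_dim = 4, 8, 16
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0

all_weight = torch.empty(n_experts, out_dim, in_dim, dtype=torch.float8_e4m3fn)
all_scales = torch.empty(n_experts, dtype=torch.float32)
max_input_scale = None

for i in range(n_experts):
    w = torch.randn(out_dim, in_dim)                     # stand-in expert weight
    scale = w.abs().max().float() / fp8_max              # scalar weight scale
    all_weight[i] = (w / scale).to(torch.float8_e4m3fn)  # quantized weight
    all_scales[i] = scale

    input_scale = torch.rand(())                         # stand-in activation scale
    if max_input_scale is None or input_scale > max_input_scale:
        max_input_scale = input_scale

print(all_weight.shape, all_scales.shape, max_input_scale)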
@@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch


def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    topk_weights = torch.nn.functional.softmax(
        gating_output, dim=1, dtype=torch.float32
    )
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
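A usage sketch for the two routing helpers above; the import path follows the package layout implied by the `.fused_moe` import in the MoE `__init__` and may need adjusting to your checkout, and all tensor sizes are invented.

# Shapes returned by fused_topk and grouped_topk on random router logits.
import torch
from text_generation_server.layers.moe.fused_moe import fused_topk, grouped_topk

tokens, n_experts, n_group = 6, 8, 4
hidden_states = torch.randn(tokens, 32)          # only its batch size matters here
gating_output = torch.randn(tokens, n_experts)   # router logits

# Plain top-k over all experts.
w, ids = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)
print(w.shape, ids.shape)  # (6, 2) weights summing to 1 per token, (6, 2) expert ids

# Grouped top-k: experts are split into n_group groups, the best topk_group
# groups are kept per token, then top-k is taken inside the surviving groups.
w, ids = grouped_topk(
    hidden_states, gating_output, topk=2, renormalize=True,
    num_expert_group=n_group, topk_group=2,
)
print(w.shape, ids.shape)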
@@ -0,0 +1,121 @@
from typing import Optional

import torch
import torch.nn as nn

from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import DynamicFusedMOE


class UnquantizedSparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
        )

        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
        for i in range(n_experts):
            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return self.hpu_fused_moe(x, gating_output, self.topk)


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_weights_row(
            f"{prefix}.{i}.{name}",
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight
backends/gaudi/server/text_generation_server/layers/rotary.py
@@ -0,0 +1,606 @@
import os
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from habana_frameworks.torch.hpex.kernels import (
|
||||||
|
RotaryPosEmbeddingMode,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_inv_freq(dim, base, device):
|
||||||
|
inv_freq = 1.0 / (
|
||||||
|
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||||
|
)
|
||||||
|
return inv_freq
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rope_config(config):
|
||||||
|
if os.getenv("ROPE_SCALING", None) is not None:
|
||||||
|
rope_scaling = {
|
||||||
|
"type": os.environ["ROPE_SCALING"],
|
||||||
|
"factor": float(os.environ["ROPE_FACTOR"]),
|
||||||
|
}
|
||||||
|
return rope_scaling
|
||||||
|
return getattr(config, "rope_scaling", None)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionRotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
|
||||||
|
super().__init__()
|
||||||
|
self.inv_freq = inv_freq
|
||||||
|
self._seq_len_cached = 0
|
||||||
|
self._cos_cached = None
|
||||||
|
self._sin_cached = None
|
||||||
|
self._cos_k_cached = None
|
||||||
|
self._sin_k_cached = None
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
):
|
||||||
|
num_tokens = query.shape[0]
|
||||||
|
head_size = query.shape[-1]
|
||||||
|
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
|
||||||
|
# to query hidden dimension, so the original tensors need to be
|
||||||
|
# expanded
|
||||||
|
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
|
||||||
|
# and expansion of cos/sin tensors via concatenation
|
||||||
|
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
||||||
|
cos = torch.cat((cos, cos), dim=-1)
|
||||||
|
sin = torch.cat((sin, sin), dim=-1)
|
||||||
|
rotary_dim = cos.shape[-1]
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, head_size)
|
||||||
|
query_rot = query[..., :rotary_dim]
|
||||||
|
query_pass = query[..., rotary_dim:]
|
||||||
|
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||||
|
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||||
|
|
||||||
|
key_shape = key.shape
|
||||||
|
key = key.view(num_tokens, -1, head_size)
|
||||||
|
key_rot = key[..., :rotary_dim]
|
||||||
|
key_pass = key[..., rotary_dim:]
|
||||||
|
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||||
|
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def static(cls, config, dim, base, device):
|
||||||
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
|
scaling_factor = None
|
||||||
|
rope_scaling = _get_rope_config(config)
|
||||||
|
if not hasattr(config, "max_position_embeddings") and hasattr(
|
||||||
|
config, "max_seq_len"
|
||||||
|
):
|
||||||
|
# handling for dbrx
|
||||||
|
config.max_position_embeddings = config.max_seq_len
|
||||||
|
if rope_scaling is not None:
|
||||||
|
# `rope_type` is now standard in transformers, but some existing models
|
||||||
|
# have `type` instead.
|
||||||
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
|
||||||
|
|
||||||
|
if rope_type == "linear":
|
||||||
|
pass
|
||||||
|
elif rope_type == "default":
|
||||||
|
pass
|
||||||
|
elif rope_type == "mrope":
|
||||||
|
mrope_section = rope_scaling["mrope_section"]
|
||||||
|
if mrope_section is not None:
|
||||||
|
return RotaryPositionEmbeddingMultimodalSections(
|
||||||
|
inv_freq,
|
||||||
|
scaling_factor,
|
||||||
|
mrope_section,
|
||||||
|
config.max_position_embeddings,
|
||||||
|
)
|
||||||
|
elif rope_type == "dynamic":
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
return DynamicPositionRotaryEmbedding(
|
||||||
|
dim=dim,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
base=base,
|
||||||
|
device=inv_freq.device,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
)
|
||||||
|
elif rope_type == "llama3":
|
||||||
|
inv_freq = apply_llama3_scaling(
|
||||||
|
inv_freq,
|
||||||
|
scaling_factor=rope_scaling["factor"],
|
||||||
|
low_freq_factor=rope_scaling["low_freq_factor"],
|
||||||
|
high_freq_factor=rope_scaling["high_freq_factor"],
|
||||||
|
original_max_position_embeddings=rope_scaling[
|
||||||
|
"original_max_position_embeddings"
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||||
|
|
||||||
|
elif rope_type == "yarn":
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
mscale = rope_scaling.get("mscale", 1.0)
|
||||||
|
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
|
||||||
|
return YarnPositionRotaryEmbedding(
|
||||||
|
dim=2 * inv_freq.shape[0],
|
||||||
|
max_position_embeddings=rope_scaling[
|
||||||
|
"original_max_position_embeddings"
|
||||||
|
],
|
||||||
|
base=base,
|
||||||
|
device=inv_freq.device,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
extrapolation_factor=1,
|
||||||
|
attn_factor=1,
|
||||||
|
beta_fast=32,
|
||||||
|
beta_slow=1,
|
||||||
|
mscale=mscale,
|
||||||
|
mscale_all_dim=mscale_all_dim,
|
||||||
|
)
|
||||||
|
elif rope_type in ["su", "longrope"]:
|
||||||
|
short_factor = torch.tensor(
|
||||||
|
rope_scaling["short_factor"], dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
short_inv_freq = 1.0 / (
|
||||||
|
short_factor
|
||||||
|
* base
|
||||||
|
** (
|
||||||
|
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
|
||||||
|
/ dim
|
||||||
|
)
|
||||||
|
)
|
||||||
|
long_factor = torch.tensor(
|
||||||
|
rope_scaling["long_factor"], dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
long_inv_freq = 1.0 / (
|
||||||
|
long_factor
|
||||||
|
* base
|
||||||
|
** (
|
||||||
|
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
|
||||||
|
/ dim
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
original_max_position_embeddings = (
|
||||||
|
config.original_max_position_embeddings
|
||||||
|
)
|
||||||
|
max_position_embeddings = config.max_position_embeddings
|
||||||
|
if max_position_embeddings <= original_max_position_embeddings:
|
||||||
|
scaling_factor = 1.0
|
||||||
|
else:
|
||||||
|
scale = max_position_embeddings / original_max_position_embeddings
|
||||||
|
scaling_factor = math.sqrt(
|
||||||
|
1 + math.log(scale) / math.log(original_max_position_embeddings)
|
||||||
|
)
|
||||||
|
|
||||||
|
# if short_mscale and long_mscale are provided we need to scale the freqs
|
||||||
|
# using the Phi3LongRoPEScaledRotaryEmbedding
|
||||||
|
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
|
||||||
|
short_mscale = rope_scaling["short_mscale"]
|
||||||
|
long_mscale = rope_scaling["long_mscale"]
|
||||||
|
return Phi3LongRoPEScaledRotaryEmbedding(
|
||||||
|
short_inv_freq=short_inv_freq,
|
||||||
|
long_inv_freq=long_inv_freq,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
short_mscale=short_mscale,
|
||||||
|
long_mscale=long_mscale,
|
||||||
|
original_max_position_embeddings=original_max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
return SuRotaryEmbedding(
|
||||||
|
short_inv_freq=short_inv_freq,
|
||||||
|
long_inv_freq=long_inv_freq,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
original_max_position_embeddings=original_max_position_embeddings,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||||
|
)
|
||||||
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, config, prefix, weights):
|
||||||
|
# XXX: Always load this in float32 !
|
||||||
|
dtype = weights.dtype
|
||||||
|
weights.dtype = torch.float32
|
||||||
|
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
|
||||||
|
weights.dtype = dtype
|
||||||
|
|
||||||
|
scaling_factor = None
|
||||||
|
rope_scaling = _get_rope_config(config)
|
||||||
|
if rope_scaling is not None:
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
if rope_scaling["type"] == "linear":
|
||||||
|
pass
|
||||||
|
elif rope_scaling["type"] == "dynamic":
|
||||||
|
return DynamicPositionRotaryEmbedding(
|
||||||
|
dim=2 * inv_freq.shape[0],
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
base=10000.0,
|
||||||
|
device=inv_freq.device,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
)
|
||||||
|
elif rope_scaling["type"] == "yarn":
|
||||||
|
mscale = rope_scaling.get("mscale", 1.0)
|
||||||
|
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
|
||||||
|
return YarnPositionRotaryEmbedding(
|
||||||
|
dim=2 * inv_freq.shape[0],
|
||||||
|
max_position_embeddings=rope_scaling[
|
||||||
|
"original_max_position_embeddings"
|
||||||
|
],
|
||||||
|
base=10000.0,
|
||||||
|
device=inv_freq.device,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
extrapolation_factor=1,
|
||||||
|
attn_factor=1,
|
||||||
|
beta_fast=32,
|
||||||
|
beta_slow=1,
|
||||||
|
mscale=mscale,
|
||||||
|
mscale_all_dim=mscale_all_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||||
|
)
|
||||||
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
|
# Reset the tables if the sequence length has changed,
|
||||||
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
if self.scaling_factor is not None:
|
||||||
|
t /= self.scaling_factor
|
||||||
|
# Don't do einsum, it converts fp32 to fp16
|
||||||
|
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
|
def get_cos_sin(self, position_ids: torch.Tensor):
|
||||||
|
|
||||||
|
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||||
|
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||||
|
|
||||||
|
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||||
|
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
short_inv_freq,
|
||||||
|
long_inv_freq,
|
||||||
|
scaling_factor,
|
||||||
|
original_max_position_embeddings,
|
||||||
|
max_position_embeddings,
|
||||||
|
):
|
||||||
|
super(PositionRotaryEmbedding, self).__init__()
|
||||||
|
self.short_inv_freq = short_inv_freq
|
||||||
|
self.long_inv_freq = long_inv_freq
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
self.original_max_position_embeddings = original_max_position_embeddings
|
||||||
|
self._seq_len_cached = 0
|
||||||
|
self._cos_cached = None
|
||||||
|
self._sin_cached = None
|
||||||
|
self._cos_k_cached = None
|
||||||
|
self._sin_k_cached = None
|
||||||
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, short_inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
|
# Reset the tables if the sequence length has changed,
|
||||||
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached is None
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
|
||||||
|
short_freqs = torch.outer(
|
||||||
|
t[: self.original_max_position_embeddings],
|
||||||
|
self.short_inv_freq.to(device=t.device),
|
||||||
|
)
|
||||||
|
long_freqs = torch.outer(
|
||||||
|
t[self.original_max_position_embeddings :],
|
||||||
|
self.long_inv_freq.to(device=t.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = torch.cat([short_freqs, long_freqs])
|
||||||
|
|
||||||
|
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
|
||||||
|
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
short_inv_freq: torch.Tensor,
|
||||||
|
long_inv_freq: torch.Tensor,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
short_mscale: float,
|
||||||
|
long_mscale: float,
|
||||||
|
original_max_position_embeddings: int,
|
||||||
|
):
|
||||||
|
super(PositionRotaryEmbedding, self).__init__()
|
||||||
|
self.short_inv_freq = short_inv_freq
|
||||||
|
self.long_inv_freq = long_inv_freq
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.short_mscale = short_mscale
|
||||||
|
self.long_mscale = long_mscale
|
||||||
|
self.original_max_position_embeddings = original_max_position_embeddings
|
||||||
|
|
||||||
|
# cache
|
||||||
|
self._seq_len_cached = 0
|
||||||
|
self._cos_cached = None
|
||||||
|
self._sin_cached = None
|
||||||
|
self._cos_k_cached = None
|
||||||
|
self._sin_k_cached = None
|
||||||
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, short_inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached is None
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
|
||||||
|
|
||||||
|
short_freqs = torch.outer(
|
||||||
|
t[: self.original_max_position_embeddings],
|
||||||
|
self.short_inv_freq.to(device=t.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
long_freqs = torch.outer(
|
||||||
|
t[self.original_max_position_embeddings :],
|
||||||
|
self.long_inv_freq.to(device=t.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
short_freqs = short_freqs * self.short_mscale
|
||||||
|
long_freqs = long_freqs * self.long_mscale
|
||||||
|
|
||||||
|
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
|
||||||
|
freqs[: self.original_max_position_embeddings] = short_freqs
|
||||||
|
freqs[self.original_max_position_embeddings :] = long_freqs
|
||||||
|
|
||||||
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
|
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||||
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
|
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
|
# Reset the tables if the sequence length has changed,
|
||||||
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
if seqlen > self.max_position_embeddings:
|
||||||
|
newbase = self.base * (
|
||||||
|
(self.scaling_factor * seqlen / self.max_position_embeddings)
|
||||||
|
- (self.scaling_factor - 1)
|
||||||
|
) ** (self.dim / (self.dim - 2))
|
||||||
|
self.inv_freq = _create_inv_freq(
|
||||||
|
self.dim, newbase, self.inv_freq.device
|
||||||
|
)
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
# Don't do einsum, it converts fp32 to fp16
|
||||||
|
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
||||||
|
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||||
|
2 * math.log(base)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Find dim range bounds based on rotations
|
||||||
|
def find_correction_range(
|
||||||
|
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||||
|
):
|
||||||
|
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||||
|
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||||
|
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||||
|
|
||||||
|
|
||||||
|
def linear_ramp_mask(min, max, dim):
|
||||||
|
if min == max:
|
||||||
|
max += 0.001 # Prevent singularity
|
||||||
|
|
||||||
|
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
||||||
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
|
return ramp_func
|
||||||
|
|
||||||
|
|
||||||
|
def get_mscale(scale: float = 1.0, mscale: float = 1.0):
|
||||||
|
if scale <= 1:
|
||||||
|
return 1.0
|
||||||
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
device,
|
||||||
|
scaling_factor,
|
||||||
|
*,
|
||||||
|
extrapolation_factor,
|
||||||
|
attn_factor,
|
||||||
|
beta_fast,
|
||||||
|
beta_slow,
|
||||||
|
mscale: float,
|
||||||
|
mscale_all_dim: float,
|
||||||
|
):
|
||||||
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
|
super().__init__(
|
||||||
|
inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
|
||||||
|
)
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
self.extrapolation_factor = extrapolation_factor
|
||||||
|
self.attn_factor = attn_factor
|
||||||
|
self.beta_fast = beta_fast
|
||||||
|
self.beta_slow = beta_slow
|
||||||
|
self.mscale_all_dim = mscale_all_dim
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
self.mscale = float(
|
||||||
|
get_mscale(self.scaling_factor, mscale)
|
||||||
|
/ get_mscale(self.scaling_factor, mscale_all_dim)
|
||||||
|
* self.attn_factor
|
||||||
|
) # Get n-d magnitude scaling corrected for interpolation
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
|
# Reset the tables if the sequence length has changed,
|
||||||
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
if seqlen > self.max_position_embeddings or True:
|
||||||
|
inv_freq_extrapolation = _create_inv_freq(
|
||||||
|
self.dim, self.base, self.inv_freq.device
|
||||||
|
)
|
||||||
|
freqs = 1.0 / inv_freq_extrapolation
|
||||||
|
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
||||||
|
low, high = find_correction_range(
|
||||||
|
self.beta_fast,
|
||||||
|
self.beta_slow,
|
||||||
|
self.dim,
|
||||||
|
self.base,
|
||||||
|
self.max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
inv_freq_mask = (
|
||||||
|
1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
|
||||||
|
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
||||||
|
inv_freq = (
|
||||||
|
inv_freq_interpolation * (1 - inv_freq_mask)
|
||||||
|
+ inv_freq_extrapolation * inv_freq_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
self.inv_freq = inv_freq
|
||||||
|
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
# Don't do einsum, it converts fp32 to fp16
|
||||||
|
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
|
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
|
||||||
|
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_llama3_scaling(
|
||||||
|
freqs: torch.Tensor,
|
||||||
|
*,
|
||||||
|
scaling_factor: int,
|
||||||
|
low_freq_factor: int,
|
||||||
|
high_freq_factor: int,
|
||||||
|
original_max_position_embeddings: int,
|
||||||
|
):
|
||||||
|
low_freq_wavelen = original_max_position_embeddings / low_freq_factor
|
||||||
|
high_freq_wavelen = original_max_position_embeddings / high_freq_factor
|
||||||
|
new_freqs = []
|
||||||
|
|
||||||
|
for freq in freqs:
|
||||||
|
wavelen = 2 * math.pi / freq
|
||||||
|
|
||||||
|
if wavelen < high_freq_wavelen:
|
||||||
|
new_freqs.append(freq)
|
||||||
|
elif wavelen > low_freq_wavelen:
|
||||||
|
new_freqs.append(freq / scaling_factor)
|
||||||
|
else:
|
||||||
|
assert low_freq_wavelen != high_freq_wavelen
|
||||||
|
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
|
||||||
|
high_freq_factor - low_freq_factor
|
||||||
|
)
|
||||||
|
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
|
||||||
|
|
||||||
|
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inv_freq: torch.Tensor,
|
||||||
|
scaling_factor: float,
|
||||||
|
sections: list,
|
||||||
|
max_position_embeddings,
|
||||||
|
):
|
||||||
|
self.sections = sections
|
||||||
|
self._cos_cached = None
|
||||||
|
self._sin_cached = None
|
||||||
|
self.section_indices = (
|
||||||
|
torch.arange(len(self.sections))
|
||||||
|
.repeat_interleave(torch.tensor(self.sections))
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.to(inv_freq.device)
|
||||||
|
)
|
||||||
|
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(
|
||||||
|
self, dtype: torch.dtype, device: torch.device, seqlen: int
|
||||||
|
):
|
||||||
|
# always cache the cos/sin for the full sequence length to avoid
|
||||||
|
# recomputing if the sequence length is smaller than the cached one
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
self._sections = self.section_indices.expand(seqlen, -1, -1)
|
||||||
|
|
||||||
|
def get_cos_sin(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
slen = position_ids.shape[0]
|
||||||
|
|
||||||
|
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
|
||||||
|
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
|
||||||
|
return cos, sin
|
@@ -0,0 +1,52 @@
import torch
import json
from typing import Tuple, Optional
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.mlp import MLPSpeculatorHead


class SpeculativeHead(torch.nn.Module):
    def __init__(self, lm_head, speculator):
        super().__init__()
        self.head = lm_head
        self.speculator = speculator

    @staticmethod
    def load(config, prefix: str, weights):
        speculator = config.speculator
        if speculator:
            speculator_path = config.speculator["path"]
            speculator_config = str(speculator_path / "config.json")

            with open(speculator_config, "r") as f:
                speculator_config = json.load(f)

            config.speculator_config = speculator_config
            try:
                architecture = speculator_config["architectures"][0]

                if architecture == "MLPSpeculatorPreTrainedModel":
                    speculator = MLPSpeculatorHead.load(config, prefix, weights)
                else:
                    speculator = None
            except KeyError:
                try:
                    speculator = MedusaHeadV1.load(config, prefix, weights)
                except Exception:
                    speculator = MedusaHeadV2(config, prefix, weights)
            lm_head = None
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            speculator = None
        return SpeculativeHead(lm_head, speculator)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.speculator is not None:
            return self.speculator(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None
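A sketch of the dispatch performed by SpeculativeHead.load above: the speculator's config.json decides which head gets built. The "architectures", "inner_dim", "tie_weights", and "scale_input" keys are the ones read by the code above; the remaining field is purely illustrative.

# Illustrative speculator configs and the branch each one takes.
mlp_speculator_config = {
    "architectures": ["MLPSpeculatorPreTrainedModel"],  # -> MLPSpeculatorHead.load(...)
    "inner_dim": 4096,     # read later via config.speculator_config["inner_dim"]
    "tie_weights": True,   # -> MLPSpeculatorModelTied instead of MLPSpeculatorModel
    "scale_input": True,   # -> simple_norm() applied to the hidden states
}

medusa_like_config = {
    # No "architectures" key: the KeyError branch tries MedusaHeadV1, then MedusaHeadV2.
    "medusa_num_heads": 4,  # hypothetical field, shown only to mark the other branch
}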
@@ -0,0 +1,244 @@
import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear

import habana_frameworks.torch as htorch


class LayerConcat(torch.nn.Module):
    """
    Apply multiple layers to the input and concatenate their
    outputs.
    """

    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
        """
        `dim` is the dimension along which layer outputs are concatenated.
        """
        super().__init__()
        self.layers = layers
        self.dim = dim

    def forward(self, x: torch.Tensor):
        outputs = [layer(x) for layer in self.layers]
        return torch.cat(outputs, self.dim)


class SuperLayer(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if config.quantize == "exl2":
            try:
                # If the piece and LM head embeddings are shared, we have
                # non-quantized weights...
                weight = weights.get_tensor(f"{prefix}.weight")
            except Exception:
                # ...otherwise they are quantized.
                weight = weights.get_weights_col(prefix)
            should_gather = weights.process_group.size() > 1
        elif weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)
            htorch.core.mark_step()
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]

        htorch.core.mark_step()
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_gate_up(prefix)
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_qkv(
        cls,
        config,
        prefix: str,
        weights,
        bias: bool,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(
            prefix,
            num_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
        )
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_weights_col(prefix)
        if bias:
            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        if config.quantize == "exl2":
            linears = []
            for prefix in prefixes:
                weight = weights.get_weights_col(prefix)
                b = weights.get_tensor(f"{prefix}.bias") if bias else None
                linears.append(get_linear(weight, b))
            linear = LayerConcat(linears)
        else:
            weight = weights.get_multi_weights_col(prefixes, dim=dim)
            if bias:
                b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
                bias = torch.cat(b, dim=dim)
            else:
                bias = None
            linear = get_linear(weight, bias)
        return cls(linear)


class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_weights_row(prefix)

        if bias and weights.process_group.rank() == 0:
            # Rank is only on the first rank process
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
            # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
            # (which is required for tensor parallel HPUGraph inference)
            htorch.core.mark_step()
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


class TensorParallelEmbedding(torch.nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = (num_embeddings + world_size - 1) // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = weight.shape[
            0
        ]  # Usually block_size, might be less in non even vocab_size.
        self.process_group = weights.process_group
        self.reduce = reduce

        """Additional 0 entry used for masking"""
        self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
            # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
            # (which is required for tensor parallel HPUGraph inference)
            htorch.core.mark_step()
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
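The core trick in `TensorParallelEmbedding` above is that each rank keeps only a slice `[min_id, max_id)` of the vocabulary plus one padded zero row; token ids outside the local slice are remapped to that zero row, so summing the per-rank lookups (the `all_reduce` in the real module, preceded by `htorch.core.mark_step()` on Gaudi) recovers the full embedding. A single-process illustration of the same arithmetic, with made-up sizes and no distributed group:

import torch
import torch.nn.functional as F

# Illustrative sketch only: simulate two vocab shards in one process and
# check that summing their partial lookups equals the full-table lookup.
vocab_size, hidden, world_size = 10, 4, 2
full_table = torch.arange(vocab_size * hidden, dtype=torch.float32).view(vocab_size, hidden)

ids = torch.tensor([[1, 7, 3, 9]])
partial_sum = torch.zeros(1, ids.shape[1], hidden)
block = (vocab_size + world_size - 1) // world_size  # ids per rank

for rank in range(world_size):
    min_id, max_id = rank * block, min(vocab_size, (rank + 1) * block)
    shard = F.pad(full_table[min_id:max_id], (0, 0, 0, 1))  # extra zero row for masking
    null_idx = shard.shape[0] - 1
    local_ids = torch.where((ids < min_id) | (ids >= max_id), null_idx, ids - min_id)
    partial_sum += F.embedding(local_ids, shard)  # stands in for all_reduce(sum)

assert torch.equal(partial_sum[0], full_table[ids[0]])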
994
backends/gaudi/server/text_generation_server/models/__init__.py
Normal file
@@ -0,0 +1,994 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum

# Needed to properly setup habana_frameworks

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)

from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights

from text_generation_server.utils.log import log_master
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = False
if ATTENTION == "paged":
    FLASH_ATTENTION = True

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
        FlashDeepseekV3ForCausalLM,
        DeepseekV3Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemmaBatch,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.flash_mllama import (
        FlashMllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_llava_next import (
        FlashLlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.idefics3 import (
        Idefics3ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_5_vl import (
        Qwen2_5VLForConditionalGeneration,
        Qwen2_5_VLConfig,
        Qwen2_5_VLProcessor,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)


class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    DEEPSEEK_V3 = {
        "type": "deepseek_v3",
        "name": "Deepseek V3",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    IDEFICS3 = {
        "type": "idefics3",
        "name": "Idefics 3",
        "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GRANITE = {
        "type": "granite",
        "name": "Granite",
        "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "mamba",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    QWEN2_VL = {
        "type": "qwen2_vl",
        "name": "Qwen 2 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    }
    QWEN2_5_VL = {
        "type": "qwen2_5_vl",
        "name": "Qwen 2.5 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }


__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]

SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    global FLASH_ATTENTION

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]

    kv_cache_dtype = dtype

    if FLASH_ATTENTION:
        if model_type == DEEPSEEK_V2:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
                head_size=head_size,
            )
        elif model_type == DEEPSEEK_V3:
            head_size = max(
                config_dict.get("qk_nope_dim", 128)
                + config_dict.get("qk_rope_dim", 64),
                config_dict.get("v_head_dim", 128),
            )
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDeepseekV3ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV3Config,
                head_size=head_size,
            )

        elif (
            model_type == GPT_BIGCODE
            or model_type == GPT2
            and model_id.startswith("bigcode/")
        ):
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashSantacoderForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
            )
        elif model_type == GPT2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPT2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GPTJ:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTJForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GPT_NEOX:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGPTNeoXForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
            )
        elif model_type == PHI:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashPhiForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == PHI_MOE:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == BAICHUAN:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GEMMA:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemmaForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == GEMMA2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashGemma2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == COHERE:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashCohereForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == DBRX:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashDbrxForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
            )
        elif (
            model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
            and not sharded
            and not config_dict.get("alibi", False)
        ):
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashRWForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                aliases={
                    "lm_head.weight": ["transformer.word_embeddings.weight"],
                    "transformer.word_embeddings.weight": ["lm_head.weight"],
                },
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=RWConfig,
            )
        elif model_type == MISTRAL:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMistralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == MIXTRAL:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashMixtralForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == STARCODER2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashStarcoder2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == QWEN2:
            return FlashCausalLM(
                model_id=model_id,
                model_class=Qwen2ForCausalLM,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == QWEN2_VL:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Qwen2VLForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == QWEN2_5_VL:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Qwen2_5VLForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=Qwen2_5_VLConfig,
                processor_class=Qwen2_5_VLProcessor,
            )
        elif model_type == MLLAMA:
            return FlashMllamaCausalLM(
                model_id=model_id,
                model_class=FlashMllamaForConditionalGeneration,
                batch_class=FlashMllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        elif model_type == IDEFICS2:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
            )
        elif model_type == IDEFICS3:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=Idefics3ForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 1456}},
            )
        elif model_type == PALIGEMMA:
            return FlashVlmCausalLM(
                model_id=model_id,
                model_class=PaliGemmaForConditionalGeneration,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
            )
        elif model_type == LLAVA_NEXT:
            return FlashVlmCausalLM(
                model_class=FlashLlavaNextForConditionalGeneration,
                model_id=model_id,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
            )

    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )

    adapt_transformers_to_gaudi()
    if SDP_ON_BF16 == 1:
        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
    if model_type == "gpt_bigcode":
        return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
    if model_type == "bloom":
        return BLOOM(
            model_id=model_id,
            revision=revision,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "llava_next":
        return VlmCausalLM(
            model_class=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=None,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "mllama":
        return VlmCausalLM(
            model_class=MllamaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=None,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    raise ValueError(f"Unsupported model type {model_type}")


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)

        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
                "qkv_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model
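One detail worth noting in the `models/__init__.py` diff above: the `__GLOBALS = locals()` loop injects every `ModelType` member as a module-level string constant (for example `LLAMA = "llama"`), which is what the later `model_type == LLAMA` comparisons in `get_model` rely on, and why the file disables ruff's `undefined-name` rule. A small self-contained sketch of the same pattern, using a made-up enum rather than the real `ModelType`:

import enum

class _DemoType(enum.Enum):
    # Hypothetical members, mirroring the layout of ModelType above.
    LLAMA = {"type": "llama", "name": "Llama"}
    GPT2 = {"type": "gpt2", "name": "Gpt2"}

_globals = globals()
for member in _DemoType:
    # Injects LLAMA = "llama" and GPT2 = "gpt2" at module scope, so later code
    # can compare `model_type == LLAMA` instead of spelling the raw string.
    _globals[member.name] = member.value["type"]

assert LLAMA == "llama" and GPT2 == "gpt2"  # noqa: F821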
52
backends/gaudi/server/text_generation_server/models/bloom.py
Normal file
@@ -0,0 +1,52 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import torch

from typing import Optional, Type

from transformers import PreTrainedTokenizerBase

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2


class BloomCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
        batch.keys_head_dim_last = False
        return batch


class BLOOM(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(BLOOM, self).__init__(
            model_id=model_id,
            revision=revision,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch
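The `BLOOM` wrapper above mostly exists to swap in `BloomCausalLMBatch`, which flips `keys_head_dim_last` to `False` for BLOOM's key-cache layout. A minimal, framework-free sketch of that `batch_type` hook pattern, with made-up class names standing in for the real CausalLM machinery:

from typing import Type

# Illustrative only: stand-in classes showing how a subclass overrides the
# `batch_type` property to route batch construction to a custom batch class.
class DemoBatch:
    keys_head_dim_last: bool = True

class DemoBloomBatch(DemoBatch):
    keys_head_dim_last = False

class DemoModel:
    @property
    def batch_type(self) -> Type[DemoBatch]:
        return DemoBatch

class DemoBloom(DemoModel):
    @property
    def batch_type(self) -> Type[DemoBatch]:
        return DemoBloomBatch

assert DemoBloom().batch_type.keys_head_dim_last is False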
1426
backends/gaudi/server/text_generation_server/models/causal_lm.py
Normal file
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff.