mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

Merge branch 'main' into cpu_perf

Commit fb4d2080af
@@ -4,3 +4,4 @@ server/transformers
 server/flash-attention
 cmake-build-debug/
 cmake-build-release/
+Dockerfile*
.github/workflows/build.yaml (vendored, 32 changes)
@@ -21,9 +21,11 @@ jobs:
   build-and-push:
     outputs:
       docker_image: ${{ steps.final.outputs.docker_image }}
+      docker_volume: ${{ steps.final.outputs.docker_volume }}
       docker_devices: ${{ steps.final.outputs.docker_devices }}
       runs_on: ${{ steps.final.outputs.runs_on }}
       label: ${{ steps.final.outputs.label }}
+      extra_pytest: ${{ steps.final.outputs.extra_pytest }}
     concurrency:
       group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
@@ -44,32 +46,39 @@ jobs:
           cuda)
             export dockerfile="Dockerfile"
             export label_extension=""
+            export docker_volume="/mnt/cache"
             export docker_devices=""
             export runs_on="aws-g6-12xl-plus-priv-cache"
             export platform=""
+            export extra_pytest=""
             ;;
           rocm)
             export dockerfile="Dockerfile_amd"
             export label_extension="-rocm"
             export docker_devices="/dev/kfd,/dev/dri"
-            # TODO Re-enable when they pass.
-            # export runs_on="amd-gpu-tgi"
-            export runs_on="ubuntu-latest"
+            export docker_volume="/mnt"
+            export runs_on="amd-gpu-runners"
             export platform=""
+            export extra_pytest="-k test_flash_gemma_gptq_load"
             ;;
           intel-xpu)
             export dockerfile="Dockerfile_intel"
             export label_extension="-intel-xpu"
             export docker_devices=""
+            export docker_volume="/mnt/cache"
             export runs_on="ubuntu-latest"
             export platform="xpu"
+            export extra_pytest=""
             ;;
           intel-cpu)
             export dockerfile="Dockerfile_intel"
             export label_extension="-intel-cpu"
-            export docker_devices=""
-            export runs_on="ubuntu-latest"
+            export docker_devices="none"
+            export docker_volume="/mnt/cache"
+            # export runs_on="ubuntu-latest"
+            export runs_on="aws-highmemory-32-plus-priv"
             export platform="cpu"
+            export extra_pytest="-k test_flash_gemma_simple"
             ;;
           esac
           echo $dockerfile
@@ -81,8 +90,10 @@ jobs:
           echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
           echo "LABEL=${label_extension}" >> $GITHUB_ENV
           echo "PLATFORM=${platform}" >> $GITHUB_ENV
+          echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
          echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
          echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
+          echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
          echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
       - name: Initialize Docker Buildx
         uses: docker/setup-buildx-action@v3
@@ -157,16 +168,18 @@ jobs:
        run: |
          echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
          echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
+          echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
          echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
          echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+          echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
   integration_tests:
     concurrency:
       group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     needs: build-and-push
-    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
     runs-on:
       group: ${{ needs.build-and-push.outputs.runs_on }}
+    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
     env:
       PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
     steps:
@@ -177,15 +190,16 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: "3.10"
+          python-version: "3.11"
       - name: Install
         run: |
           make install-integration-tests
       - name: Run tests
         run: |
-          export DOCKER_VOLUME=/mnt/cache
+          export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}
           export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
           export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
+          export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}"
           export HF_TOKEN=${{ secrets.HF_TOKEN }}
           echo $DOCKER_IMAGE
-          pytest -s -vv integration-tests ${PYTEST_FLAGS}
+          pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
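Read together, these workflow changes thread two new per-hardware values (`docker_volume` and `extra_pytest`) from the `case` block through `$GITHUB_ENV`, the `final` step's `$GITHUB_OUTPUT`, and into the test job. A rough local equivalent of the reworked "Run tests" step, as a sketch (the image tag, token, and registry are placeholders, using the intel-cpu matrix values):

```bash
# Sketch: mimic the CI "Run tests" step outside of Actions.
export DOCKER_VOLUME=/mnt/cache                    # from docker_volume
export DOCKER_DEVICES=none                         # from docker_devices
export DOCKER_IMAGE=<registry>/text-generation-inference:sha-<short-sha>-intel-cpu
export EXTRA_PYTEST="-k test_flash_gemma_simple"   # narrows the suite for this hardware
export HF_TOKEN=<your-hf-read-token>

make install-integration-tests
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
```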
Cargo.lock (generated, 196 changes)
@@ -133,7 +133,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -172,7 +172,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -183,7 +183,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -205,9 +205,9 @@ dependencies = [
 
 [[package]]
 name = "autocfg"
-version = "1.3.0"
+version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
+checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
 
 [[package]]
 name = "av1-grain"
@@ -316,12 +316,12 @@ dependencies = [
 
 [[package]]
 name = "axum"
-version = "0.7.6"
+version = "0.7.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec"
+checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
 dependencies = [
  "async-trait",
- "axum-core 0.4.4",
+ "axum-core 0.4.5",
  "bytes",
  "futures-util",
  "http 1.1.0",
@@ -367,9 +367,9 @@ dependencies = [
 
 [[package]]
 name = "axum-core"
-version = "0.4.4"
+version = "0.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00"
+checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
 dependencies = [
  "async-trait",
  "bytes",
@@ -392,7 +392,7 @@ version = "0.16.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
 dependencies = [
- "axum 0.7.6",
+ "axum 0.7.7",
  "futures-core",
  "futures-util",
  "http 1.1.0",
@@ -456,7 +456,7 @@ dependencies = [
  "regex",
  "rustc-hash",
  "shlex",
- "syn 2.0.77",
+ "syn 2.0.79",
  "which",
 ]
 
@@ -605,9 +605,9 @@ dependencies = [
 
 [[package]]
 name = "cc"
-version = "1.1.21"
+version = "1.1.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
+checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0"
 dependencies = [
  "jobserver",
  "libc",
@@ -704,7 +704,7 @@ dependencies = [
  "heck 0.5.0",
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -971,7 +971,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "scratch",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -988,7 +988,7 @@ checksum = "98532a60dedaebc4848cb2cba5023337cc9ea3af16a5b062633fabfd9f18fb60"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -1012,7 +1012,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "strsim",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -1023,7 +1023,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
 dependencies = [
  "darling_core",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -1053,7 +1053,7 @@ dependencies = [
  "darling",
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -1063,7 +1063,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc"
 dependencies = [
  "derive_builder_core",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -1192,9 +1192,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
 
 [[package]]
 name = "fdeflate"
-version = "0.3.4"
+version = "0.3.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645"
+checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
 dependencies = [
  "simd-adler32",
 ]
@@ -1207,9 +1207,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
 
 [[package]]
 name = "flate2"
-version = "1.0.33"
+version = "1.0.34"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253"
+checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
 dependencies = [
  "crc32fast",
  "miniz_oxide 0.8.0",
@@ -1338,7 +1338,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -1864,7 +1864,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c"
 dependencies = [
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -1884,7 +1884,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -2270,7 +2270,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
 dependencies = [
  "adler",
- "simd-adler32",
 ]
 
 [[package]]
@@ -2280,6 +2279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
 dependencies = [
  "adler2",
+ "simd-adler32",
 ]
 
 [[package]]
@@ -2319,7 +2319,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -2519,7 +2519,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -2599,9 +2599,12 @@ dependencies = [
 
 [[package]]
 name = "once_cell"
-version = "1.19.0"
+version = "1.20.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
+checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1"
+dependencies = [
+ "portable-atomic",
+]
 
 [[package]]
 name = "onig"
@@ -2654,7 +2657,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -2808,7 +2811,7 @@ dependencies = [
  "glob",
  "once_cell",
  "opentelemetry 0.21.0",
- "ordered-float 4.2.2",
+ "ordered-float 4.3.0",
  "percent-encoding",
  "rand",
  "thiserror",
@@ -2828,7 +2831,7 @@ dependencies = [
  "lazy_static",
  "once_cell",
  "opentelemetry 0.23.0",
- "ordered-float 4.2.2",
+ "ordered-float 4.3.0",
  "percent-encoding",
  "rand",
  "thiserror",
@@ -2851,9 +2854,9 @@ dependencies = [
 
 [[package]]
 name = "ordered-float"
-version = "4.2.2"
+version = "4.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6"
+checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537"
 dependencies = [
  "num-traits",
 ]
@@ -2937,7 +2940,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -2988,22 +2991,22 @@ dependencies = [
 
 [[package]]
 name = "png"
-version = "0.17.13"
+version = "0.17.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1"
+checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0"
 dependencies = [
  "bitflags 1.3.2",
  "crc32fast",
  "fdeflate",
  "flate2",
- "miniz_oxide 0.7.4",
+ "miniz_oxide 0.8.0",
 ]
 
 [[package]]
 name = "portable-atomic"
-version = "1.8.0"
+version = "1.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce"
+checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
 
 [[package]]
 name = "powerfmt"
@@ -3027,7 +3030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba"
 dependencies = [
  "proc-macro2",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -3079,7 +3082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
 dependencies = [
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -3119,7 +3122,7 @@ dependencies = [
  "prost 0.12.6",
  "prost-types",
  "regex",
- "syn 2.0.77",
+ "syn 2.0.79",
  "tempfile",
 ]
 
@@ -3146,7 +3149,7 @@ dependencies = [
  "itertools 0.12.1",
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -3205,7 +3208,7 @@ dependencies = [
  "proc-macro2",
  "pyo3-macros-backend",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -3218,7 +3221,7 @@ dependencies = [
  "proc-macro2",
  "pyo3-build-config",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -3402,9 +3405,9 @@ dependencies = [
 
 [[package]]
 name = "redox_syscall"
-version = "0.5.5"
+version = "0.5.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0"
+checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
 dependencies = [
  "bitflags 2.6.0",
 ]
@@ -3422,14 +3425,14 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.10.6"
+version = "1.11.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
+checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-automata 0.4.7",
- "regex-syntax 0.8.4",
+ "regex-automata 0.4.8",
+ "regex-syntax 0.8.5",
 ]
 
 [[package]]
@@ -3443,13 +3446,13 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.4.7"
+version = "0.4.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
+checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-syntax 0.8.4",
+ "regex-syntax 0.8.5",
 ]
 
 [[package]]
@@ -3460,9 +3463,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
 
 [[package]]
 name = "regex-syntax"
-version = "0.8.4"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
+checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
 
 [[package]]
 name = "reqwest"
@@ -3563,7 +3566,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "rust-embed-utils",
- "syn 2.0.77",
+ "syn 2.0.79",
  "walkdir",
 ]
 
@@ -3686,9 +3689,9 @@ dependencies = [
 
 [[package]]
 name = "rustls-pki-types"
-version = "1.8.0"
+version = "1.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
+checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55"
 
 [[package]]
 name = "rustls-webpki"
@@ -3813,7 +3816,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -3840,9 +3843,9 @@ dependencies = [
 
 [[package]]
 name = "serde_spanned"
-version = "0.6.7"
+version = "0.6.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d"
+checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
 dependencies = [
  "serde",
 ]
@@ -4028,7 +4031,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "rustversion",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -4050,9 +4053,9 @@ dependencies = [
 
 [[package]]
 name = "syn"
-version = "2.0.77"
+version = "2.0.79"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
+checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -4152,9 +4155,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
 
 [[package]]
 name = "tempfile"
-version = "3.12.0"
+version = "3.13.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
+checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
 dependencies = [
  "cfg-if",
  "fastrand",
@@ -4174,7 +4177,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-backends-trtllm"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4197,7 +4200,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "average",
  "clap 4.5.18",
@@ -4217,7 +4220,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4235,7 +4238,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "clap 4.5.18",
  "ctrlc",
@@ -4244,6 +4247,7 @@ dependencies = [
  "nix 0.28.0",
+ "once_cell",
  "pyo3",
  "regex",
  "reqwest",
  "serde",
  "serde_json",
@@ -4255,11 +4259,11 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
- "axum 0.7.6",
+ "axum 0.7.7",
  "axum-tracing-opentelemetry",
  "base64 0.22.1",
  "clap 4.5.18",
@@ -4304,11 +4308,11 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router-v2"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
- "axum 0.7.6",
+ "axum 0.7.7",
  "axum-tracing-opentelemetry",
  "base64 0.22.1",
  "clap 4.5.18",
@@ -4353,11 +4357,11 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router-v3"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
- "axum 0.7.6",
+ "axum 0.7.7",
  "axum-tracing-opentelemetry",
  "base64 0.22.1",
  "clap 4.5.18",
@@ -4428,7 +4432,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -4533,7 +4537,7 @@ dependencies = [
  "rayon",
  "rayon-cond",
  "regex",
- "regex-syntax 0.8.4",
+ "regex-syntax 0.8.5",
  "serde",
  "serde_json",
  "spm_precompiled",
@@ -4566,7 +4570,7 @@ dependencies = [
  "rayon",
  "rayon-cond",
  "regex",
- "regex-syntax 0.8.4",
+ "regex-syntax 0.8.5",
  "serde",
  "serde_json",
  "spm_precompiled",
@@ -4612,7 +4616,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -4771,7 +4775,7 @@ dependencies = [
  "proc-macro2",
  "prost-build",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -4858,7 +4862,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -5151,7 +5155,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "regex",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -5160,7 +5164,7 @@ version = "6.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
 dependencies = [
- "axum 0.7.6",
+ "axum 0.7.7",
  "mime_guess",
  "regex",
  "rust-embed",
@@ -5189,7 +5193,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
@@ -5290,7 +5294,7 @@ dependencies = [
  "once_cell",
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
  "wasm-bindgen-shared",
 ]
 
@@ -5324,7 +5328,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
  "wasm-bindgen-backend",
  "wasm-bindgen-shared",
 ]
@@ -5668,9 +5672,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
 
 [[package]]
 name = "winnow"
-version = "0.6.19"
+version = "0.6.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c52ac009d615e79296318c1bcce2d422aaca15ad08515e344feeda07df67a587"
+checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
 dependencies = [
  "memchr",
 ]
@@ -5703,7 +5707,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.77",
+ "syn 2.0.79",
 ]
 
 [[package]]
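Every change above is a lockfile-only bump; the recurring `syn 2.0.77` to `syn 2.0.79` pairs ripple through each crate's dependency list because the lockfile records the resolved version per dependent. Bumps like these are typically produced with `cargo update`; a sketch:

```bash
# Move one crate to a precise version (rewrites Cargo.lock only):
cargo update -p syn --precise 2.0.79

# Or refresh all dependencies to the newest semver-compatible releases:
cargo update
```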
@@ -20,7 +20,7 @@ default-members = [
 resolver = "2"
 
 [workspace.package]
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 ARG GIT_SHA
 ARG DOCKER_LABEL
 
+COPY Cargo.lock Cargo.lock
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
@@ -39,7 +40,7 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen
 
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
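The `--frozen` flag (together with the new `COPY Cargo.lock Cargo.lock` line in the same change) makes the image build fail fast if the lockfile is missing or out of sync, instead of silently re-resolving dependencies during `docker build`. A sketch of the behavior:

```bash
# --frozen behaves like --locked plus --offline: cargo errors out rather
# than updating Cargo.lock or touching the network during the build.
cargo build --profile release-opt --frozen
```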
Dockerfile.nix (new file, 24 lines)
@@ -0,0 +1,24 @@
+# Build the image and get out the docker file:
+#
+# docker build -t tgi-nix-builder -f Dockerfile.nix
+# docker run --log-driver=none tgi-nix-builder | docker load
+
+FROM nixos/nix:2.18.8 AS builder
+RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
+RUN nix profile install nixpkgs#cachix
+RUN cachix use text-generation-inference
+WORKDIR /root
+ADD . .
+RUN nix build .
+RUN mkdir /tmp/nix-store-closure
+RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure
+
+FROM ubuntu:24.04
+
+WORKDIR /app
+
+# Copy /nix/store
+COPY --from=builder /tmp/nix-store-closure /nix/store
+COPY --from=builder /root/result /app
+RUN ldconfig
+CMD ["ldconfig", "/app/bin/text-generation-launcher"]
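The comment block at the top of the new file describes the intended workflow; spelled out as a runnable sequence (a sketch, using the tag from the file's own comments):

```bash
# Build the Nix builder image, stream the image tarball it produces on
# stdout, and load it into the local Docker daemon:
docker build -t tgi-nix-builder -f Dockerfile.nix .
docker run --log-driver=none tgi-nix-builder | docker load
```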
Dockerfile_amd (180 changes)
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -31,6 +31,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 ARG GIT_SHA
 ARG DOCKER_LABEL
 
+COPY Cargo.lock Cargo.lock
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
@@ -38,10 +39,10 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen
 
 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
+FROM rocm/dev-ubuntu-22.04:6.2 AS base
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -50,33 +51,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     git \
     make \
+    libmsgpack-dev \
     libssl-dev \
+    llvm-dev \
     g++ \
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
     hipblas-dev \
     hipblaslt-dev \
-    hipcub-dev \
     rocblas-dev \
     hiprand-dev \
+    hipfft-dev \
     rocrand-dev \
     miopen-hip-dev \
-    hipfft-dev \
+    hipcub-dev \
     hipsolver-dev \
     rccl-dev \
     cmake \
-    python3.11-dev && \
+    python3.11-venv && \
     rm -rf /var/lib/apt/lists/*
 
 # Keep in sync with `server/pyproject.toml
 ARG MAMBA_VERSION=23.1.0-1
-ARG PYTORCH_VERSION='2.3.0'
-ARG ROCM_VERSION='6.0.2'
+ARG PYTHON_VERSION='3.11.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
 
 ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
 
 # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
 # Install mamba
@@ -100,41 +102,132 @@ RUN case ${TARGETPLATFORM} in \
         /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
     esac && \
     /opt/conda/bin/conda clean -ya
 
 # Install flash-attention, torch dependencies
-RUN pip install numpy einops ninja --no-cache-dir
+RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
 
-RUN pip uninstall -y triton && \
-    git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
-    cd triton/python && \
-    pip install .
+RUN conda install mkl=2021
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
 
-RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
 
 ARG _GLIBCXX_USE_CXX11_ABI="1"
 ARG CMAKE_PREFIX_PATH="/opt/conda"
+ARG COMMON_WORKDIR=/
+WORKDIR ${COMMON_WORKDIR}
+
+
+# Install HIPBLASLt
+FROM base AS build_hipblaslt
+ARG HIPBLASLT_BRANCH="e6da924"
+RUN git clone https://github.com/ROCm/hipBLASLt.git \
+    && cd hipBLASLt \
+    && git checkout ${HIPBLASLT_BRANCH} \
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
+    && cd build/release \
+    && make package
+
+FROM scratch AS export_hipblaslt
+ARG COMMON_WORKDIR
+COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
+
+# RCCL build stages
+FROM base AS build_rccl
+ARG RCCL_BRANCH="rocm-6.2.0"
+RUN git clone https://github.com/ROCm/rccl \
+    && cd rccl \
+    && git checkout ${RCCL_BRANCH} \
+    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
+FROM scratch AS export_rccl
+ARG COMMON_WORKDIR
+COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
+
+# Triton build stages
+FROM base AS build_triton
+ARG TRITON_BRANCH="e192dba"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
+    && cd triton \
+    && git checkout ${TRITON_BRANCH} \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_triton
+ARG COMMON_WORKDIR
+COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
+
+# # AMD-SMI build stages
+FROM base AS build_amdsmi
+RUN cd /opt/rocm/share/amd_smi \
+    && pip wheel . --wheel-dir=dist
+FROM scratch AS export_amdsmi
+COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
+
+
+FROM base as build_pytorch
+
+RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+    if ls /install/*.deb; then \
+    dpkg -i /install/*.deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+    fi
+
+ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
+ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+ARG BUILD_CAFFE2="0" \
+    BUILD_CAFFE2_OPS="0" \
+    USE_CUDA="0" \
+    USE_ROCM="1" \
+    BUILD_TEST="0" \
+    USE_FBGEMM="0" \
+    USE_NNPACK="0" \
+    USE_QNNPACK="0" \
+    USE_XNNPACK="0" \
+    USE_FLASH_ATTENTION="1" \
+    USE_MEM_EFF_ATTENTION="0"
 
-RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
+# A commit to fix the output scaling factor issue in _scaled_mm
+# Not yet in 2.5.0-rc1
+ARG PYTORCH_BRANCH="cedc116"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
 
-# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
-ENV HIP_FORCE_DEV_KERNARG=1
+# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+ENV HIP_FORCE_DEV_KERNARG=1
+RUN git clone ${PYTORCH_REPO} pytorch \
+    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
+    && pip install -r requirements.txt --no-cache-dir \
+    && python tools/amd_build/build_amd.py \
+    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch as export_pytorch
+ARG COMMON_WORKDIR
+COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
 
-# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
-# However, Triton requires a tunning for each prompt length, which is prohibitive.
-ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
-FROM base AS kernel-builder
+FROM base AS install_deps
 
 ARG COMMON_WORKDIR
 
+# Install hipblaslt
+RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+    if ls /install/*.deb; then \
+    dpkg -i /install/*.deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+    fi
+
+RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
+    if ls /install/*.deb; then \
+    dpkg -i /install/*.deb \
+    # RCCL needs to be installed twice
+    && dpkg -i /install/*.deb \
+    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
+    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
+    fi
+
+RUN --mount=type=bind,from=export_triton,src=/,target=/install \
+    if ls /install/*.whl; then \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y triton \
+    && pip install /install/*.whl; \
+    fi
+
+RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y amdsmi \
+    && pip install /install/*.whl;
+
+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+    if ls /install/*.whl; then \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y torch torchvision \
+    && pip install /install/*.whl; \
+    fi
+
+FROM install_deps AS kernel-builder
 
 # # Build vllm kernels
 FROM kernel-builder AS vllm-builder
@@ -174,7 +267,7 @@ COPY server/exllamav2_kernels/ .
 
 RUN python setup.py build
 
-FROM base AS base-copy
+FROM install_deps AS base-copy
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -224,6 +317,19 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy
 
+# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+ENV HIP_FORCE_DEV_KERNARG=1
+
+# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
+# However, Triton requires a tunning for each prompt length, which is prohibitive.
+ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV ROCM_USE_SKINNY_GEMM=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
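Each of the new `FROM scratch AS export_*` stages holds nothing but build artifacts (`.deb` packages or wheels), which the `install_deps` stage then consumes via bind mounts. The same layout also lets you pull artifacts straight out of the build; a sketch (requires BuildKit; the destination directory is arbitrary):

```bash
# Extract the Triton wheels produced by the build_triton stage:
DOCKER_BUILDKIT=1 docker build \
    -f Dockerfile_amd \
    --target export_triton \
    --output type=local,dest=./triton-wheels \
    .
ls ./triton-wheels/*.whl
```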
@@ -1,6 +1,6 @@
 ARG PLATFORM=xpu
 
-FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
 ARG GIT_SHA
 ARG DOCKER_LABEL
 
+COPY Cargo.lock Cargo.lock
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
@@ -39,7 +40,7 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen
 
 
 # Text Generation Inference base image for Intel
@@ -52,7 +53,7 @@ ARG MAMBA_VERSION=23.1.0-1
 ARG PYTHON_VERSION='3.11.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
 
 # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
 # Install mamba
@@ -111,6 +112,8 @@ ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
 ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+ENV TORCH_LLM_ALLREDUCE=1
+ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -127,12 +130,22 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     ca-certificates \
     make \
-    g++ \
+    g++-12 \
+    gcc-12 \
     git \
     wget \
     cmake \
     libnuma-dev
 
+RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
+RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
+RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
+RUN update-alternatives --set cc /usr/bin/gcc
+
+RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
+RUN update-alternatives --set c++ /usr/bin/g++
+
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
@@ -164,16 +177,17 @@ RUN case ${TARGETPLATFORM} in \
 
 RUN conda install -c conda-forge gperftools mkl
 
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
+
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
 
 RUN pip install triton py-libnuma
 
 WORKDIR /usr/src
 
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
-RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
 
 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
@@ -83,7 +83,7 @@ model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data
 
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.3.0 --model-id $model
+    ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
 ```
 
 And then you can make requests like
@@ -120,7 +120,7 @@ curl localhost:3000/v1/chat/completions \
 
 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
 
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1-rocm --model-id $model` instead of the command above.
 
 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
@@ -150,7 +150,7 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>
 
-docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
 ```
 
 ### A note on Shared Memory (shm)
@@ -100,6 +100,7 @@ pub async fn connect_backend(
         .map_err(V3Error::Warmup)?,
     )?;
     tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
+    metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
 
     let backend_info = BackendInfo {
         waiting_served_ratio,
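The new gauge makes the warmup-derived batch budget visible on the router's Prometheus endpoint; a quick way to check it, as a sketch (assuming the default port mapping used in the README examples):

```bash
curl -s localhost:8080/metrics | grep tgi_batch_max_total_tokens
```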
@@ -27,3 +27,6 @@ asyncio_mode = "auto"
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.isort]
+profile = "black"
@@ -28,11 +28,17 @@ class ToolCall(BaseModel):
     function: dict
 
 
+class Chunk(BaseModel):
+    type: str
+    text: Optional[str] = None
+    image_url: Any = None
+
+
 class Message(BaseModel):
     # Role of the message sender
     role: str
     # Content of the message
-    content: Optional[str] = None
+    content: Optional[Union[str, List[Chunk]]] = None
     # Optional name of the message sender
     name: Optional[str] = None
     # Tool calls associated with the chat completion
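With `content` widened from `Optional[str]` to `Optional[Union[str, List[Chunk]]]`, the Python client can represent OpenAI-style multimodal messages. A sketch of the equivalent raw request (endpoint and image URL are placeholders):

```bash
curl localhost:8080/v1/chat/completions \
    -H 'Content-Type: application/json' \
    -d '{
      "model": "tgi",
      "messages": [{
        "role": "user",
        "content": [
          {"type": "text", "text": "What is shown in this image?"},
          {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}
        ]
      }]
    }'
```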
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "2.3.1-dev0"
+    "version": "2.3.2-dev0"
   },
   "paths": {
     "/": {
@@ -2114,12 +2114,18 @@
       "ToolType": {
         "oneOf": [
           {
-            "type": "object",
-            "default": null,
-            "nullable": true
+            "type": "string",
+            "description": "Means the model can pick between generating a message or calling one or more tools.",
+            "enum": [
+              "auto"
+            ]
           },
           {
-            "type": "string"
+            "type": "string",
+            "description": "Means the model will not call any tool and instead generates a message.",
+            "enum": [
+              "none"
+            ]
           },
           {
             "type": "object",
@@ -2131,13 +2137,10 @@
             "$ref": "#/components/schemas/FunctionName"
           }
         }
-      },
-      {
-        "type": "object",
-        "default": null,
-        "nullable": true
       }
-    ]
+    ],
+    "description": "Controls which (if any) tool is called by the model.",
+    "example": "auto"
   },
   "Url": {
     "type": "object",
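The schema change turns the previously untyped nullable objects into proper `"auto"` / `"none"` string enums for `tool_choice`. A sketch of a request exercising the documented values (the tool definition is a minimal placeholder):

```bash
curl localhost:8080/v1/chat/completions \
    -H 'Content-Type: application/json' \
    -d '{
      "model": "tgi",
      "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
      "tools": [{"type": "function", "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {"type": "object",
                       "properties": {"city": {"type": "string"}},
                       "required": ["city"]}}}],
      "tool_choice": "auto"
    }'
```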
@@ -3,6 +3,8 @@
   title: Text Generation Inference
 - local: quicktour
   title: Quick Tour
+- local: supported_models
+  title: Supported Models
 - local: installation_nvidia
   title: Using TGI with Nvidia GPUs
 - local: installation_amd
@@ -15,8 +17,7 @@
   title: Using TGI with Intel GPUs
 - local: installation
   title: Installation from source
-- local: supported_models
-  title: Supported Models and Hardware
+
 - local: architecture
   title: Internal Architecture
 - local: usage_statistics
@@ -10,7 +10,7 @@ This diagram shows well there are these separate components:
 
 - **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
 - **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
-- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
+- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
 
 The router and the model server can be two different machines, they do not need to be deployed together.
@@ -19,6 +19,6 @@ docker run --gpus all \
     --shm-size 1g \
     -e HF_TOKEN=$token \
     -p 8080:80 \
-    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
+    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 \
     --model-id $model
 ```
@@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
 LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
 ```
 
+To specify model revision, use `adapter_id@revision`, as follows:
+
+```bash
+LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
+```
+
 To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
 
 ```bash
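Once an adapter is loaded at startup (from the Hub, with a pinned `@revision`, or from a local path), a request selects it through the `adapter_id` parameter quoted in the docs above; a sketch:

```bash
curl localhost:8080/generate \
    -H 'Content-Type: application/json' \
    -d '{
      "inputs": "Classify this ticket: my package never arrived",
      "parameters": {"adapter_id": "predibase/customer_support"}
    }'
```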
@@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
 In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
 
 ```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes
 ```
 
 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
 In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
 
 ```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes-nf4
 ```
 
 You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
 TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
 
 ```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize gptq
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize gptq
 ```
 
 Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
@@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
     --device=/dev/kfd --device=/dev/dri --group-add video \
     --ipc=host --shm-size 256g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.3.0-rocm \
+    ghcr.io/huggingface/text-generation-inference:2.3.1-rocm \
     --model-id $model
 ```
 
@@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
 
 By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
 
+## Custom PagedAttention
+
+For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
+
+The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
+
 ## Unsupported features
 
 The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
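Disabling the new custom PagedAttention kernel is just an environment toggle on the `docker run` invocation shown at the top of the same page; a sketch:

```bash
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    --ipc=host --shm-size 256g --net host -v $volume:/data \
    -e ROCM_USE_CUSTOM_PAGED_ATTN=0 \
    ghcr.io/huggingface/text-generation-inference:2.3.1-rocm \
    --model-id $model
```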
@@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm --privileged --cap-add=sys_nice \
     --device=/dev/dri \
     --ipc=host --shm-size 1g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.3.0-intel-xpu \
+    ghcr.io/huggingface/text-generation-inference:2.3.1-intel-xpu \
     --model-id $model --cuda-graphs 0
 ```
 
@@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm --privileged --cap-add=sys_nice \
     --device=/dev/dri \
     --ipc=host --shm-size 1g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.3.0-intel-cpu \
+    ghcr.io/huggingface/text-generation-inference:2.3.1-intel-cpu \
     --model-id $model --cuda-graphs 0
 ```
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
 docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.3.0 \
+    ghcr.io/huggingface/text-generation-inference:2.3.1 \
     --model-id $model
 ```
@ -11,14 +11,13 @@ model=teknium/OpenHermes-2.5-Mistral-7B

volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.3.0 \
    ghcr.io/huggingface/text-generation-inference:2.3.1 \
    --model-id $model
```
<Tip>

If you want to serve gated or private models, which provide
controlled access to sensitive or proprietary content, refer to
If you want to serve gated or private models, please refer to
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
for detailed instructions.
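For instance, access to a gated model is typically granted by passing a Hugging Face access token into the container; the `HF_TOKEN` variable name follows current Hub conventions and the model id is illustrative:

```bash
# Minimal sketch: pass a Hugging Face access token for a gated model.
export HF_TOKEN=<your Hugging Face access token>
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    -e HF_TOKEN=$HF_TOKEN \
    ghcr.io/huggingface/text-generation-inference:2.3.1 \
    --model-id meta-llama/Llama-3.1-8B-Instruct
```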
@ -97,7 +96,7 @@ curl 127.0.0.1:8080/generate \

To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

```bash
docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help
docker run ghcr.io/huggingface/text-generation-inference:2.3.1 --help
```

</Tip>
@ -89,6 +89,15 @@ Options:
          [env: DTYPE=]
          [possible values: float16, bfloat16]

```
## KV_CACHE_DTYPE
```shell
      --kv-cache-dtype <KV_CACHE_DTYPE>
          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA

          [env: KV_CACHE_DTYPE=]
          [possible values: fp8_e5m2]

```
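For example, serving with an FP8 key-value cache might look like this (the model variable is illustrative; per the text above, the flag currently requires a CUDA device):

```bash
# Minimal sketch: quantize the KV cache to fp8_e5m2 on a CUDA GPU.
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.3.1 \
    --model-id $model --kv-cache-dtype fp8_e5m2
```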
## TRUST_REMOTE_CODE
```shell
@ -1,9 +1,7 @@

# Supported Models and Hardware
# Supported Models

Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported.

## Supported Models
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.

- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
@ -20,6 +18,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
- [Phi](https://huggingface.co/microsoft/phi-1_5)
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
@ -34,6 +33,8 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)

If the above list lacks the model you would like to serve, then depending on the model's pipeline type, you can try to initialize and serve it anyway to see how well it performs, but performance isn't guaranteed for non-optimized models:
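A minimal sketch of that fallback launch (the invocation pattern is the standard launcher one; the model id is an illustrative placeholder):

```bash
# Minimal sketch: attempt to serve a model that is not on the optimized list.
text-generation-launcher --model-id <org>/<model>
```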
24
flake.lock
@ -497,11 +497,11 @@
        "systems": "systems_7"
      },
      "locked": {
        "lastModified": 1710146030,
        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
        "lastModified": 1726560853,
        "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
        "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
        "type": "github"
      },
      "original": {
@ -718,11 +718,11 @@
      },
      "nixpkgs_6": {
        "locked": {
          "lastModified": 1724915739,
          "narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
          "lastModified": 1727675176,
          "narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
          "owner": "nixos",
          "repo": "nixpkgs",
          "rev": "85be051bb60943d3328d91aaf2598798f87e19af",
          "rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
          "type": "github"
        },
        "original": {
@ -853,11 +853,11 @@
        ]
      },
      "locked": {
        "lastModified": 1726626348,
        "narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
        "lastModified": 1727836133,
        "narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
        "owner": "oxalica",
        "repo": "rust-overlay",
        "rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
        "rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
        "type": "github"
      },
      "original": {
@ -978,11 +978,11 @@
        "nixpkgs": "nixpkgs_6"
      },
      "locked": {
        "lastModified": 1727353315,
        "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
        "lastModified": 1728381423,
        "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
        "owner": "huggingface",
        "repo": "text-generation-inference-nix",
        "rev": "1d42c4125ebafb87707118168995675cc5050b9d",
        "rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
        "type": "github"
      },
      "original": {
14
flake.nix
@ -37,6 +37,7 @@
        overlays = [
          rust-overlay.overlays.default
          tgi-nix.overlays.default
          (import nix/overlay.nix)
        ];
      };
      crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
@ -141,7 +142,8 @@
          };
        };

        packages.default = pkgs.writeShellApplication {
        packages = rec {
          default = pkgs.writeShellApplication {
            name = "text-generation-inference";
            runtimeInputs = [
              server
@ -151,6 +153,16 @@
              ${launcher}/bin/text-generation-launcher "$@"
            '';
          };

          dockerImage = pkgs.callPackage nix/docker.nix {
            text-generation-inference = default;
          };

          dockerImageStreamed = pkgs.callPackage nix/docker.nix {
            text-generation-inference = default;
            stream = true;
          };
        };
      }
    );
}
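Assuming these flake outputs behave like standard nixpkgs `dockerTools` images, they could be built and loaded roughly like this (the command shapes are an assumption, not taken from the repo docs):

```bash
# Build the Docker image tarball and load it into Docker.
nix build .#dockerImage
docker load < ./result

# Or stream the image without materializing a tarball in the Nix store.
nix build .#dockerImageStreamed
./result | docker load
```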
@ -336,6 +336,7 @@ def launcher(event_loop):
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
@ -375,6 +376,9 @@ def launcher(event_loop):
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if kv_cache_dtype is not None:
            args.append("--kv-cache-dtype")
            args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
@ -434,6 +438,7 @@ def launcher(event_loop):
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
@ -456,6 +461,9 @@ def launcher(event_loop):
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if kv_cache_dtype is not None:
            args.append("--kv-cache-dtype")
            args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
@ -484,6 +492,7 @@ def launcher(event_loop):
        try:
            container = client.containers.get(container_name)
            container.stop()
            container.remove()
            container.wait()
        except NotFound:
            pass
@ -506,13 +515,28 @@ def launcher(event_loop):
        volumes = [f"{DOCKER_VOLUME}:/data"]

        if DOCKER_DEVICES:
            devices = DOCKER_DEVICES.split(",")
        # A DOCKER_DEVICES value of "none" explicitly disables device passthrough.
        if DOCKER_DEVICES.lower() == "none":
            devices = []
        else:
            devices = DOCKER_DEVICES.strip().split(",")
            visible = os.getenv("ROCR_VISIBLE_DEVICES")
            if visible:
                env["ROCR_VISIBLE_DEVICES"] = visible
        device_requests = []
        if not devices:
            devices = None
        # CDI device names (e.g. nvidia.com/gpu=all) map to per-GPU device requests.
        elif devices == ["nvidia.com/gpu=all"]:
            devices = None
            device_requests = [
                docker.types.DeviceRequest(
                    driver="cdi",
                    # count=gpu_count,
                    device_ids=[f"nvidia.com/gpu={i}"],
                )
                for i in range(gpu_count)
            ]
        else:
            devices = []
            devices = None
            device_requests = [
                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
            ]
@ -532,6 +556,7 @@ def launcher(event_loop):
            shm_size="1G",
        )

        try:
            yield ContainerLauncherHandle(client, container.name, port)

        if not use_flash_attention:
@ -546,7 +571,11 @@ def launcher(event_loop):
            container_output = container.logs().decode("utf-8")
            print(container_output, file=sys.stderr)

        finally:
            try:
                container.remove()
            except Exception:
                pass

    if DOCKER_IMAGE is not None:
        return docker_launcher
@ -589,7 +618,6 @@ def generate_multi():
        max_new_tokens: int,
        seed: Optional[int] = None,
    ) -> List[Response]:

        import numpy as np

        arange = np.arange(len(prompts))
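With these hooks in place, an integration run against a prebuilt image might be driven like this (the environment variable names come from the fixture above; the image tag and test filter are illustrative):

```bash
# Run the integration tests against an existing image, with no devices mapped.
export DOCKER_IMAGE=ghcr.io/huggingface/text-generation-inference:2.3.1
export DOCKER_DEVICES=none
pytest integration-tests -k <some_test_filter>
```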
@ -0,0 +1,104 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [ {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"}, {"id": 3923, "logprob": -5.6328125, "text": "What"}, {"id": 374, "logprob": -1.2265625, "text": " is"}, {"id": 5655, "logprob": -9.1015625, "text": " deep"}, {"id": 6975, "logprob": -1.8085938, "text": " learning"}, {"id": 30, "logprob": -1.0439453, "text": "?"} ],
    "seed": null,
    "tokens": [ {"id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep"}, {"id": 6975, "logprob": -0.079956055, "special": false, "text": " learning"}, {"id": 374, "logprob": -0.2763672, "special": false, "text": " is"}, {"id": 264, "logprob": -0.37548828, "special": false, "text": " a"}, {"id": 27084, "logprob": -1.4628906, "special": false, "text": " subset"}, {"id": 315, "logprob": -0.02885437, "special": false, "text": " of"}, {"id": 5780, "logprob": -0.2565918, "special": false, "text": " machine"}, {"id": 6975, "logprob": -0.0063438416, "special": false, "text": " learning"}, {"id": 430, "logprob": -1.3056641, "special": false, "text": " that"}, {"id": 374, "logprob": -1.6035156, "special": false, "text": " is"} ],
    "top_tokens": null
  },
  "generated_text": " Deep learning is a subset of machine learning that is"
}
@ -0,0 +1,57 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 3,
    "prefill": [ {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"}, {"id": 374, "logprob": -22.96875, "text": " is"}, {"id": 5655, "logprob": -10.71875, "text": " deep"}, {"id": 6975, "logprob": -2.6992188, "text": " learning"}, {"id": 30, "logprob": -4.8398438, "text": "?"} ],
    "seed": 0,
    "tokens": [ {"id": 720, "logprob": -0.4411621, "special": false, "text": " \n"}, {"id": 220, "logprob": -0.35864258, "special": false, "text": " "}, {"id": 128001, "logprob": 0.0, "special": true, "text": "<|end_of_text|>"} ],
    "top_tokens": null
  },
  "generated_text": "What is deep learning? \n "
}
@ -0,0 +1,418 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"}, {"id": 3923, "logprob": -5.6328125, "text": "What"}, {"id": 374, "logprob": -1.2265625, "text": " is"}, {"id": 5655, "logprob": -9.1015625, "text": " deep"}, {"id": 6975, "logprob": -1.8085938, "text": " learning"}, {"id": 30, "logprob": -1.0439453, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep"}, {"id": 6975, "logprob": -0.07897949, "special": false, "text": " learning"}, {"id": 374, "logprob": -0.27734375, "special": false, "text": " is"}, {"id": 264, "logprob": -0.37402344, "special": false, "text": " a"}, {"id": 27084, "logprob": -1.4511719, "special": false, "text": " subset"}, {"id": 315, "logprob": -0.02909851, "special": false, "text": " of"}, {"id": 5780, "logprob": -0.25854492, "special": false, "text": " machine"}, {"id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning"}, {"id": 430, "logprob": -1.3046875, "special": false, "text": " that"}, {"id": 374, "logprob": -1.5537109, "special": false, "text": " is"} ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"}, {"id": 3923, "logprob": -5.6328125, "text": "What"}, {"id": 374, "logprob": -1.2265625, "text": " is"}, {"id": 5655, "logprob": -9.1015625, "text": " deep"}, {"id": 6975, "logprob": -1.8085938, "text": " learning"}, {"id": 30, "logprob": -1.0439453, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep"}, {"id": 6975, "logprob": -0.07897949, "special": false, "text": " learning"}, {"id": 374, "logprob": -0.27734375, "special": false, "text": " is"}, {"id": 264, "logprob": -0.37402344, "special": false, "text": " a"}, {"id": 27084, "logprob": -1.4511719, "special": false, "text": " subset"}, {"id": 315, "logprob": -0.02909851, "special": false, "text": " of"}, {"id": 5780, "logprob": -0.25854492, "special": false, "text": " machine"}, {"id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning"}, {"id": 430, "logprob": -1.3046875, "special": false, "text": " that"}, {"id": 374, "logprob": -1.5537109, "special": false, "text": " is"} ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"}, {"id": 3923, "logprob": -5.6328125, "text": "What"}, {"id": 374, "logprob": -1.2265625, "text": " is"}, {"id": 5655, "logprob": -9.1015625, "text": " deep"}, {"id": 6975, "logprob": -1.8085938, "text": " learning"}, {"id": 30, "logprob": -1.0439453, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep"}, {"id": 6975, "logprob": -0.07897949, "special": false, "text": " learning"}, {"id": 374, "logprob": -0.27734375, "special": false, "text": " is"}, {"id": 264, "logprob": -0.37402344, "special": false, "text": " a"}, {"id": 27084, "logprob": -1.4511719, "special": false, "text": " subset"}, {"id": 315, "logprob": -0.02909851, "special": false, "text": " of"}, {"id": 5780, "logprob": -0.25854492, "special": false, "text": " machine"}, {"id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning"}, {"id": 430, "logprob": -1.3046875, "special": false, "text": " that"}, {"id": 374, "logprob": -1.5537109, "special": false, "text": " is"} ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"}, {"id": 3923, "logprob": -5.6328125, "text": "What"}, {"id": 374, "logprob": -1.2265625, "text": " is"}, {"id": 5655, "logprob": -9.1015625, "text": " deep"}, {"id": 6975, "logprob": -1.8085938, "text": " learning"}, {"id": 30, "logprob": -1.0439453, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep"}, {"id": 6975, "logprob": -0.07897949, "special": false, "text": " learning"}, {"id": 374, "logprob": -0.27734375, "special": false, "text": " is"}, {"id": 264, "logprob": -0.37402344, "special": false, "text": " a"}, {"id": 27084, "logprob": -1.4511719, "special": false, "text": " subset"}, {"id": 315, "logprob": -0.02909851, "special": false, "text": " of"}, {"id": 5780, "logprob": -0.25854492, "special": false, "text": " machine"}, {"id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning"}, {"id": 430, "logprob": -1.3046875, "special": false, "text": " that"}, {"id": 374, "logprob": -1.5537109, "special": false, "text": " is"} ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  }
]
@ -0,0 +1,104 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 1824, "logprob": -12.296875, "text": "What"}, {"id": 349, "logprob": -0.97216797, "text": "is"}, {"id": 3534, "logprob": -10.1796875, "text": "deep"}, {"id": 5168, "logprob": -0.9658203, "text": "learning"}, {"id": 28804, "logprob": -0.44384766, "text": "?"} ],
    "seed": null,
    "tokens": [ {"id": 13, "logprob": -0.50878906, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.8876953, "special": false, "text": "\n"}, {"id": 23229, "logprob": -0.15124512, "special": false, "text": "Deep"}, {"id": 5168, "logprob": -0.030288696, "special": false, "text": " learning"}, {"id": 349, "logprob": -0.16687012, "special": false, "text": " is"}, {"id": 264, "logprob": -0.17858887, "special": false, "text": " a"}, {"id": 19804, "logprob": -0.8046875, "special": false, "text": " subset"}, {"id": 302, "logprob": -0.007205963, "special": false, "text": " of"}, {"id": 5599, "logprob": -0.090026855, "special": false, "text": " machine"}, {"id": 5168, "logprob": -0.0030670166, "special": false, "text": " learning"} ],
    "top_tokens": null
  },
  "generated_text": "\n\nDeep learning is a subset of machine learning"
}
@ -0,0 +1,99 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 349, "logprob": -13.921875, "text": "is"}, {"id": 3534, "logprob": -11.2265625, "text": "deep"}, {"id": 5168, "logprob": -2.3886719, "text": "learning"}, {"id": 28804, "logprob": -4.7109375, "text": "?"} ],
    "seed": 0,
    "tokens": [ {"id": 13, "logprob": 0.0, "special": false, "text": "\n"}, {"id": 23229, "logprob": -0.5229492, "special": false, "text": "Deep"}, {"id": 17504, "logprob": 0.0, "special": false, "text": " Learning"}, {"id": 349, "logprob": -0.5151367, "special": false, "text": " is"}, {"id": 264, "logprob": 0.0, "special": false, "text": " a"}, {"id": 19804, "logprob": 0.0, "special": false, "text": " subset"}, {"id": 302, "logprob": 0.0, "special": false, "text": " of"}, {"id": 13253, "logprob": -1.3359375, "special": false, "text": " Machine"}, {"id": 17504, "logprob": 0.0, "special": false, "text": " Learning"}, {"id": 28725, "logprob": 0.0, "special": false, "text": ","} ],
    "top_tokens": null
  },
  "generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
}
@ -0,0 +1,418 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 1824, "logprob": -12.296875, "text": "What"}, {"id": 349, "logprob": -0.97216797, "text": "is"}, {"id": 3534, "logprob": -10.1796875, "text": "deep"}, {"id": 5168, "logprob": -0.9658203, "text": "learning"}, {"id": 28804, "logprob": -0.44384766, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -0.50878906, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.8876953, "special": false, "text": "\n"}, {"id": 23229, "logprob": -0.15136719, "special": false, "text": "Deep"}, {"id": 5168, "logprob": -0.030273438, "special": false, "text": " learning"}, {"id": 349, "logprob": -0.1665039, "special": false, "text": " is"}, {"id": 264, "logprob": -0.1776123, "special": false, "text": " a"}, {"id": 19804, "logprob": -0.8076172, "special": false, "text": " subset"}, {"id": 302, "logprob": -0.007183075, "special": false, "text": " of"}, {"id": 5599, "logprob": -0.090148926, "special": false, "text": " machine"}, {"id": 5168, "logprob": -0.0030670166, "special": false, "text": " learning"} ],
      "top_tokens": null
    },
    "generated_text": "\n\nDeep learning is a subset of machine learning"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 1824, "logprob": -12.34375, "text": "What"}, {"id": 349, "logprob": -0.96728516, "text": "is"}, {"id": 3534, "logprob": -10.1796875, "text": "deep"}, {"id": 5168, "logprob": -0.97265625, "text": "learning"}, {"id": 28804, "logprob": -0.44189453, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -0.51220703, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.87402344, "special": false, "text": "\n"}, {"id": 23229, "logprob": -0.15039062, "special": false, "text": "Deep"}, {"id": 5168, "logprob": -0.030288696, "special": false, "text": " learning"}, {"id": 349, "logprob": -0.1652832, "special": false, "text": " is"}, {"id": 264, "logprob": -0.17858887, "special": false, "text": " a"}, {"id": 19804, "logprob": -0.81103516, "special": false, "text": " subset"}, {"id": 302, "logprob": -0.007183075, "special": false, "text": " of"}, {"id": 5599, "logprob": -0.08880615, "special": false, "text": " machine"}, {"id": 5168, "logprob": -0.0030612946, "special": false, "text": " learning"} ],
      "top_tokens": null
    },
    "generated_text": "\n\nDeep learning is a subset of machine learning"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 1824, "logprob": -12.34375, "text": "What"}, {"id": 349, "logprob": -0.96728516, "text": "is"}, {"id": 3534, "logprob": -10.1796875, "text": "deep"}, {"id": 5168, "logprob": -0.97265625, "text": "learning"}, {"id": 28804, "logprob": -0.44189453, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -0.51220703, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.87402344, "special": false, "text": "\n"}, {"id": 23229, "logprob": -0.15039062, "special": false, "text": "Deep"}, {"id": 5168, "logprob": -0.030288696, "special": false, "text": " learning"}, {"id": 349, "logprob": -0.1652832, "special": false, "text": " is"}, {"id": 264, "logprob": -0.17858887, "special": false, "text": " a"}, {"id": 19804, "logprob": -0.81103516, "special": false, "text": " subset"}, {"id": 302, "logprob": -0.007183075, "special": false, "text": " of"}, {"id": 5599, "logprob": -0.08880615, "special": false, "text": " machine"}, {"id": 5168, "logprob": -0.0030612946, "special": false, "text": " learning"} ],
      "top_tokens": null
    },
    "generated_text": "\n\nDeep learning is a subset of machine learning"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 1824, "logprob": -12.34375, "text": "What"}, {"id": 349, "logprob": -0.96728516, "text": "is"}, {"id": 3534, "logprob": -10.1796875, "text": "deep"}, {"id": 5168, "logprob": -0.97265625, "text": "learning"}, {"id": 28804, "logprob": -0.44189453, "text": "?"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -0.51220703, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.87402344, "special": false, "text": "\n"}, {"id": 23229, "logprob": -0.15039062, "special": false, "text": "Deep"}, {"id": 5168, "logprob": -0.030288696, "special": false, "text": " learning"}, {"id": 349, "logprob": -0.1652832, "special": false, "text": " is"}, {"id": 264, "logprob": -0.17858887, "special": false, "text": " a"}, {"id": 19804, "logprob": -0.81103516, "special": false, "text": " subset"}, {"id": 302, "logprob": -0.007183075, "special": false, "text": " of"}, {"id": 5599, "logprob": -0.08880615, "special": false, "text": " machine"}, {"id": 5168, "logprob": -0.0030612946, "special": false, "text": " learning"} ],
      "top_tokens": null
    },
    "generated_text": "\n\nDeep learning is a subset of machine learning"
  }
]
@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 3735, "logprob": -11.0078125, "text": "Test"}, {"id": 2159, "logprob": -13.59375, "text": "request"} ],
    "seed": null,
    "tokens": [ {"id": 13, "logprob": -1.7089844, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.68847656, "special": false, "text": "\n"}, {"id": 28771, "logprob": -1.9394531, "special": false, "text": "#"}, {"id": 3735, "logprob": -2.8808594, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.37280273, "special": false, "text": " request"}, {"id": 13, "logprob": -0.26098633, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.0017137527, "special": false, "text": "\n"}, {"id": 1064, "logprob": -2.2695312, "special": false, "text": "##"}, {"id": 3735, "logprob": -1.9238281, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.48828125, "special": false, "text": " request"} ],
    "top_tokens": null
  },
  "generated_text": "\n\n# Test request\n\n## Test request"
}
@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 3735, "logprob": -11.0078125, "text": "Test"}, {"id": 2159, "logprob": -13.59375, "text": "request"} ],
    "seed": 0,
    "tokens": [ {"id": 13, "logprob": -0.34838867, "special": false, "text": "\n"}, {"id": 13940, "logprob": -0.38916016, "special": false, "text": "``"}, {"id": 28832, "logprob": 0.0, "special": false, "text": "`"}, {"id": 3371, "logprob": -1.2529297, "special": false, "text": "json"}, {"id": 13, "logprob": 0.0, "special": false, "text": "\n"}, {"id": 28751, "logprob": 0.0, "special": false, "text": "{"}, {"id": 13, "logprob": 0.0, "special": false, "text": "\n"}, {"id": 2287, "logprob": 0.0, "special": false, "text": " "}, {"id": 345, "logprob": 0.0, "special": false, "text": " \""}, {"id": 3134, "logprob": -0.640625, "special": false, "text": "request"} ],
    "top_tokens": null
  },
  "generated_text": "Test request\n```json\n{\n \"request"
}
@ -0,0 +1,358 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 3735, "logprob": -11.0078125, "text": "Test"}, {"id": 2159, "logprob": -13.59375, "text": "request"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -1.7089844, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.68847656, "special": false, "text": "\n"}, {"id": 28771, "logprob": -1.9394531, "special": false, "text": "#"}, {"id": 3735, "logprob": -2.8828125, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.37329102, "special": false, "text": " request"}, {"id": 13, "logprob": -0.2602539, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.0017185211, "special": false, "text": "\n"}, {"id": 1064, "logprob": -2.2753906, "special": false, "text": "##"}, {"id": 3735, "logprob": -1.9316406, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.48217773, "special": false, "text": " request"} ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 3735, "logprob": -11.0078125, "text": "Test"}, {"id": 2159, "logprob": -13.59375, "text": "request"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -1.7089844, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.68847656, "special": false, "text": "\n"}, {"id": 28771, "logprob": -1.9394531, "special": false, "text": "#"}, {"id": 3735, "logprob": -2.8828125, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.37329102, "special": false, "text": " request"}, {"id": 13, "logprob": -0.2602539, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.0017185211, "special": false, "text": "\n"}, {"id": 1064, "logprob": -2.2753906, "special": false, "text": "##"}, {"id": 3735, "logprob": -1.9316406, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.48217773, "special": false, "text": " request"} ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 3735, "logprob": -11.0078125, "text": "Test"}, {"id": 2159, "logprob": -13.59375, "text": "request"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -1.7089844, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.68847656, "special": false, "text": "\n"}, {"id": 28771, "logprob": -1.9394531, "special": false, "text": "#"}, {"id": 3735, "logprob": -2.8828125, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.37329102, "special": false, "text": " request"}, {"id": 13, "logprob": -0.2602539, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.0017185211, "special": false, "text": "\n"}, {"id": 1064, "logprob": -2.2753906, "special": false, "text": "##"}, {"id": 3735, "logprob": -1.9316406, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.48217773, "special": false, "text": " request"} ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1, "logprob": null, "text": "<s>"}, {"id": 3735, "logprob": -11.0078125, "text": "Test"}, {"id": 2159, "logprob": -13.59375, "text": "request"} ],
      "seed": null,
      "tokens": [ {"id": 13, "logprob": -1.7089844, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.68847656, "special": false, "text": "\n"}, {"id": 28771, "logprob": -1.9394531, "special": false, "text": "#"}, {"id": 3735, "logprob": -2.8828125, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.37329102, "special": false, "text": " request"}, {"id": 13, "logprob": -0.2602539, "special": false, "text": "\n"}, {"id": 13, "logprob": -0.0017185211, "special": false, "text": "\n"}, {"id": 1064, "logprob": -2.2753906, "special": false, "text": "##"}, {"id": 3735, "logprob": -1.9316406, "special": false, "text": " Test"}, {"id": 2159, "logprob": -0.48217773, "special": false, "text": " request"} ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  }
]
@ -0,0 +1,109 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [ {"id": 1724, "logprob": null, "text": "What"}, {"id": 338, "logprob": -0.7133789, "text": "is"}, {"id": 16030, "logprob": -13.9296875, "text": "gradient"}, {"id": 26815, "logprob": -0.048919678, "text": "descent"}, {"id": 29973, "logprob": -3.0078125, "text": "?"}, {"id": 13, "logprob": -2.8105469, "text": "\n"}, {"id": 13, "logprob": -0.84521484, "text": "\n"} ],
    "seed": null,
    "tokens": [ {"id": 25584, "logprob": -0.017028809, "special": false, "text": "Grad"}, {"id": 993, "logprob": -0.0027313232, "special": false, "text": "ient"}, {"id": 26815, "logprob": -0.023254395, "special": false, "text": " descent"}, {"id": 338, "logprob": -2.0623207e-05, "special": false, "text": " is"}, {"id": 263, "logprob": -0.5361328, "special": false, "text": " a"}, {"id": 937, "logprob": -0.17578125, "special": false, "text": " first"}, {"id": 29899, "logprob": 0.0, "special": false, "text": "-"}, {"id": 2098, "logprob": -0.00011539459, "special": false, "text": "order"}, {"id": 13883, "logprob": -0.47436523, "special": false, "text": " optimization"}, {"id": 5687, "logprob": -0.00027680397, "special": false, "text": " algorithm"} ],
    "top_tokens": null
  },
  "generated_text": "Gradient descent is a first-order optimization algorithm"
}
@ -0,0 +1,99 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [ {"id": 16030, "logprob": null, "text": "gradient"}, {"id": 26815, "logprob": -6.4960938, "text": "descent"}, {"id": 29973, "logprob": -5.1484375, "text": "?"}, {"id": 13, "logprob": -4.0351562, "text": "\n"}, {"id": 13, "logprob": -5.2265625, "text": "\n"} ],
    "seed": 0,
    "tokens": [ {"id": 10994, "logprob": -1.1542969, "special": false, "text": "Hello"}, {"id": 29991, "logprob": 0.0, "special": false, "text": "!"}, {"id": 739, "logprob": 0.0, "special": false, "text": " It"}, {"id": 2444, "logprob": -0.42260742, "special": false, "text": " seems"}, {"id": 366, "logprob": 0.0, "special": false, "text": " you"}, {"id": 29915, "logprob": 0.0, "special": false, "text": "'"}, {"id": 276, "logprob": -0.9838867, "special": false, "text": "re"}, {"id": 3211, "logprob": 0.0, "special": false, "text": " address"}, {"id": 292, "logprob": 0.0, "special": false, "text": "ing"}, {"id": 263, "logprob": -0.15124512, "special": false, "text": " a"} ],
    "top_tokens": null
  },
  "generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
}
@ -0,0 +1,438 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1724, "logprob": null, "text": "What"}, {"id": 338, "logprob": -0.7133789, "text": "is"}, {"id": 16030, "logprob": -13.9296875, "text": "gradient"}, {"id": 26815, "logprob": -0.048919678, "text": "descent"}, {"id": 29973, "logprob": -3.0078125, "text": "?"}, {"id": 13, "logprob": -2.8105469, "text": "\n"}, {"id": 13, "logprob": -0.84521484, "text": "\n"} ],
      "seed": null,
      "tokens": [ {"id": 25584, "logprob": -0.017028809, "special": false, "text": "Grad"}, {"id": 993, "logprob": -0.0028476715, "special": false, "text": "ient"}, {"id": 26815, "logprob": -0.023971558, "special": false, "text": " descent"}, {"id": 338, "logprob": -2.0384789e-05, "special": false, "text": " is"}, {"id": 263, "logprob": -0.5229492, "special": false, "text": " a"}, {"id": 937, "logprob": -0.17602539, "special": false, "text": " first"}, {"id": 29899, "logprob": 0.0, "special": false, "text": "-"}, {"id": 2098, "logprob": -0.000116467476, "special": false, "text": "order"}, {"id": 13883, "logprob": -0.47436523, "special": false, "text": " optimization"}, {"id": 5687, "logprob": -0.00027871132, "special": false, "text": " algorithm"} ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1724, "logprob": null, "text": "What"}, {"id": 338, "logprob": -0.7128906, "text": "is"}, {"id": 16030, "logprob": -13.9375, "text": "gradient"}, {"id": 26815, "logprob": -0.05053711, "text": "descent"}, {"id": 29973, "logprob": -3.0058594, "text": "?"}, {"id": 13, "logprob": -2.8242188, "text": "\n"}, {"id": 13, "logprob": -0.84521484, "text": "\n"} ],
      "seed": null,
      "tokens": [ {"id": 25584, "logprob": -0.018859863, "special": false, "text": "Grad"}, {"id": 993, "logprob": -0.002822876, "special": false, "text": "ient"}, {"id": 26815, "logprob": -0.023254395, "special": false, "text": " descent"}, {"id": 338, "logprob": -2.0384789e-05, "special": false, "text": " is"}, {"id": 263, "logprob": -0.5229492, "special": false, "text": " a"}, {"id": 937, "logprob": -0.17126465, "special": false, "text": " first"}, {"id": 29899, "logprob": 0.0, "special": false, "text": "-"}, {"id": 2098, "logprob": -0.0001155138, "special": false, "text": "order"}, {"id": 13883, "logprob": -0.47436523, "special": false, "text": " optimization"}, {"id": 5687, "logprob": -0.00027036667, "special": false, "text": " algorithm"} ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1724, "logprob": null, "text": "What"}, {"id": 338, "logprob": -0.71484375, "text": "is"}, {"id": 16030, "logprob": -13.9375, "text": "gradient"}, {"id": 26815, "logprob": -0.049346924, "text": "descent"}, {"id": 29973, "logprob": -3.0078125, "text": "?"}, {"id": 13, "logprob": -2.8242188, "text": "\n"}, {"id": 13, "logprob": -0.86328125, "text": "\n"} ],
      "seed": null,
      "tokens": [ {"id": 25584, "logprob": -0.017196655, "special": false, "text": "Grad"}, {"id": 993, "logprob": -0.0028438568, "special": false, "text": "ient"}, {"id": 26815, "logprob": -0.023254395, "special": false, "text": " descent"}, {"id": 338, "logprob": -2.026558e-05, "special": false, "text": " is"}, {"id": 263, "logprob": -0.5229492, "special": false, "text": " a"}, {"id": 937, "logprob": -0.17602539, "special": false, "text": " first"}, {"id": 29899, "logprob": 0.0, "special": false, "text": "-"}, {"id": 2098, "logprob": -0.00011622906, "special": false, "text": "order"}, {"id": 13883, "logprob": -0.48608398, "special": false, "text": " optimization"}, {"id": 5687, "logprob": -0.00027894974, "special": false, "text": " algorithm"} ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [ {"id": 1724, "logprob": null, "text": "What"}, {"id": 338, "logprob": -0.7192383, "text": "is"}, {"id": 16030, "logprob": -13.9375, "text": "gradient"}, {"id": 26815, "logprob": -0.050445557, "text": "descent"}, {"id": 29973, "logprob": -3.0078125, "text": "?"}, {"id": 13, "logprob": -2.8242188, "text": "\n"}, {"id": 13, "logprob": -0.8276367, "text": "\n"} ],
      "seed": null,
      "tokens": [ {"id": 25584, "logprob": -0.01727295, "special": false, "text": "Grad"}, {"id": 993, "logprob": -0.0027542114, "special": false, "text": "ient"}, {"id": 26815, "logprob": -0.023254395, "special": false, "text": " descent"}, {"id": 338, "logprob": -2.0384789e-05, "special": false, "text": " is"}, {"id": 263, "logprob": -0.5229492, "special": false, "text": " a"}, {"id": 937, "logprob": -0.17126465, "special": false, "text": " first"}, {"id": 29899, "logprob": 0.0, "special": false, "text": "-"}, {"id": 2098, "logprob": -0.00011301041, "special": false, "text": "order"}, {"id": 13883, "logprob": -0.48608398, "special": false, "text": " optimization"}, {"id": 5687, "logprob": -0.00027894974, "special": false, "text": " algorithm"} ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  }
]
@ -0,0 +1,106 @@
[
  {
    "choices": [ { "finish_reason": "length", "index": 0, "logprobs": null, "message": { "content": "In a bustling city, a chicken named Cluck", "name": null, "role": "assistant", "tool_calls": null }, "usage": null } ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  },
  {
    "choices": [ { "finish_reason": "length", "index": 0, "logprobs": null, "message": { "content": "In a world where even chickens could dream big,", "name": null, "role": "assistant", "tool_calls": null }, "usage": null } ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  },
  {
    "choices": [ { "finish_reason": "length", "index": 0, "logprobs": null, "message": { "content": "In a world where even chickens could dream big,", "name": null, "role": "assistant", "tool_calls": null }, "usage": null } ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  },
  {
    "choices": [ { "finish_reason": "length", "index": 0, "logprobs": null, "message": { "content": "In a world where even chickens could dream big,", "name": null, "role": "assistant", "tool_calls": null }, "usage": null } ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  }
]
@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "In a bustling city, a chicken named Cluck",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1727556016,
  "id": "",
  "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.3.1-dev0-native",
  "usage": {
    "completion_tokens": 10,
    "prompt_tokens": 50,
    "total_tokens": 60
  }
}
@ -1,38 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": null,
        "content": "I am an AI assistant",
        "name": null,
        "role": "assistant",
        "tool_calls": [
          {
            "function": {
              "arguments": {
                "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
              },
              "description": null,
              "name": "notify_error"
            },
            "id": 0,
            "type": "function"
          }
        ]
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1712852597,
  "created": 1728497062,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.5-native",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.3.2-dev0-native",
  "usage": {
    "completion_tokens": 39,
    "prompt_tokens": 496,
    "total_tokens": 535
    "completion_tokens": 23,
    "prompt_tokens": 604,
    "total_tokens": 627
  }
}
@ -0,0 +1,20 @@
{
  "choices": [
    {
      "delta": {
        "content": " assistant",
        "role": "assistant",
        "tool_calls": null
      },
      "finish_reason": null,
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1728497531,
  "id": "",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "2.3.2-dev0-native",
  "usage": null
}
@ -0,0 +1,20 @@
{
  "choices": [
    {
      "delta": {
        "content": " fans",
        "role": "assistant",
        "tool_calls": null
      },
      "finish_reason": null,
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1728497461,
  "id": "",
  "model": "meta-llama/Llama-3.1-8B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "2.3.2-dev0-native",
  "usage": null
}
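Each of the chunk snapshots above is the payload of one `chat.completion.chunk` SSE event. A minimal Python consumer sketch for such a stream (the server URL and model are assumptions; the `/v1/chat/completions` path and the `[DONE]` sentinel match the router code later in this diff):

    import json
    import requests

    resp = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "messages": [{"role": "user", "content": "Who are you?"}],
            "stream": True,
            "max_tokens": 20,
        },
        stream=True,
    )
    for line in resp.iter_lines():
        # SSE frames look like: b'data: {...}' or the terminal b'data: [DONE]'
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content") or "", end="", flush=True)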
@ -16,7 +16,7 @@ async def flash_gemma(flash_gemma_handle):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
async def test_flash_gemma_simple(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
@ -15,7 +15,7 @@ async def flash_llama(flash_llama_handle):

@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama(flash_llama, response_snapshot):
async def test_flash_llama_simple(flash_llama, response_snapshot):
    response = await flash_llama.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
77 integration-tests/models/test_flash_llama_fp8_kv_cache.py Normal file
@ -0,0 +1,77 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
    await flash_llama_fp8_kv_cache_handle.health(300)
    return flash_llama_fp8_kv_cache_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
    )

    assert (
        response.generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_all_params(
    flash_llama_fp8_kv_cache, response_snapshot
):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_load(
    flash_llama_fp8_kv_cache, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert (
        responses[0].generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"Different messages : {[r.generated_text for r in responses]}"
    assert responses == response_snapshot
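Outside the test harness, the new flag can be exercised directly; a sketch, assuming a local checkout with the launcher on PATH and port 8080 free (the `--kv-cache-dtype` flag and its `fp8_e5m2` value come from the launcher diff further below, everything else is illustrative):

    import subprocess
    import requests

    # Start TGI with the fp8 KV cache; only fp8_e5m2 on CUDA is supported here.
    launcher = subprocess.Popen([
        "text-generation-launcher",
        "--model-id", "meta-llama/Meta-Llama-3-8B",
        "--num-shard", "2",
        "--kv-cache-dtype", "fp8_e5m2",
        "--port", "8080",
    ])

    # ...poll http://localhost:8080/health until it returns 200, then:
    out = requests.post(
        "http://localhost:8080/generate",
        json={"inputs": "What is deep learning?", "parameters": {"max_new_tokens": 10}},
    ).json()
    print(out["generated_text"])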
73 integration-tests/models/test_flash_mixtral_awq.py Normal file
@ -0,0 +1,73 @@
import pytest


@pytest.fixture(scope="module")
def flash_mixtral_awq_handle(launcher):
    with launcher("casperhansen/mixtral-instruct-awq", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_mixtral_awq(flash_mixtral_awq_handle):
    await flash_mixtral_awq_handle.health(300)
    return flash_mixtral_awq_handle.client


@pytest.mark.asyncio
async def test_flash_mixtral_awq(flash_mixtral_awq, response_snapshot):
    response = await flash_mixtral_awq.generate(
        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text == "\n\nDeep learning is a subset of machine learning"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_awq_all_params(flash_mixtral_awq, response_snapshot):
    response = await flash_mixtral_awq.generate(
        "What is deep learning?",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_awq_load(
    flash_mixtral_awq, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_mixtral_awq, "What is deep learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert responses[0].details.generated_tokens == 10
    assert (
        responses[0].generated_text
        == "\n\nDeep learning is a subset of machine learning"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot
60 integration-tests/models/test_flash_mixtral_gptq.py Normal file
@ -0,0 +1,60 @@
import pytest


@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
    with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
    await flash_mixtral_gptq_handle.health(300)
    return flash_mixtral_gptq_handle.client


@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
    response = await flash_mixtral_gptq.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
    response = await flash_mixtral_gptq.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_gptq_load(
    flash_mixtral_gptq, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot
75 integration-tests/models/test_flash_phi35_moe.py Normal file
@ -0,0 +1,75 @@
import pytest


@pytest.fixture(scope="module")
def flash_phi35_moe_handle(launcher):
    with launcher(
        "microsoft/Phi-3.5-MoE-instruct",
        num_shard=4,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_phi35_moe(flash_phi35_moe_handle):
    await flash_phi35_moe_handle.health(300)
    return flash_phi35_moe_handle.client


@pytest.mark.asyncio
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
    response = await flash_phi35_moe.generate(
        "What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "Gradient descent is a first-order optimization algorithm"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
    response = await flash_phi35_moe.generate(
        "What is gradient descent?\n\n",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is gradient descent?\n\nHello! It seems you're addressing a"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
    responses = await generate_load(
        flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert responses[0].details.generated_tokens == 10
    assert (
        responses[0].generated_text
        == "Gradient descent is a first-order optimization algorithm"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot
105 integration-tests/models/test_mllama.py Normal file
@ -0,0 +1,105 @@
import pytest
import base64
import asyncio


@pytest.fixture(scope="module")
def mllama_handle(launcher):
    with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def mllama(mllama_handle):
    await mllama_handle.health(300)
    return mllama_handle.client


# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
async def test_mllama_simple(mllama, response_snapshot):
    # chicken = get_chicken()
    response = await mllama.chat(
        max_tokens=10,
        temperature=0.0,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Can you tell me a very short story based on the image?",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                        },
                    },
                ],
            },
        ],
    )

    assert response.usage == {
        "completion_tokens": 10,
        "prompt_tokens": 50,
        "total_tokens": 60,
    }
    assert (
        response.choices[0].message.content
        == "In a bustling city, a chicken named Cluck"
    )
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
    futures = [
        mllama.chat(
            max_tokens=10,
            temperature=0.0,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Can you tell me a very short story based on the image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                            },
                        },
                    ],
                },
            ],
        )
        for i in range(4)
    ]
    responses = await asyncio.gather(*futures)

    generated_texts = [response.choices[0].message.content for response in responses]

    assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
    assert len(generated_texts) == 4
    assert all(
        [text == generated_texts[0] for text in generated_texts]
    )

    assert responses == response_snapshot
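The same multimodal request can be made over plain HTTP with an inline base64 image, which is what the get_chicken() helper above produces; a sketch (the server URL is an assumption, the message shape mirrors the test):

    import base64
    import requests

    with open("integration-tests/images/chicken_on_money.png", "rb") as f:
        data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

    resp = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "max_tokens": 10,
            "temperature": 0.0,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Can you tell me a very short story based on the image?"},
                    {"type": "image_url", "image_url": {"url": data_url}},
                ],
            }],
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])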
@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_shard=2,
        disable_grammar_support=False,
    ) as handle:
        yield handle

@ -205,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
    )

    count = 0
    tool_calls_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
        last_response = response
        assert response.choices[0].delta.content is None

    assert count == 48
    assert response == response_snapshot
    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
    )
    assert count == 28
    assert last_response == response_snapshot


@pytest.mark.asyncio
@ -225,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
        messages=[
            {
                "role": "system",
                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=False,
    )

    assert responses.choices[0].message.tool_calls is None
    assert responses.choices[0].message.content == "I am an AI assistant"

    assert responses == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=True,
    )

    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        assert response.choices[0].delta.tool_calls is None

    assert count == 5
    assert content_generated == "I am an AI assistant"
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=False,
        stream=True,
    )

    assert responses.choices[0].message.content is None
    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        assert response.choices[0].delta.tool_calls is None

    assert count == 62
    assert (
        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
        content_generated
        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
    )
    assert responses == response_snapshot
    assert last_response == response_snapshot

@ -13,3 +13,6 @@ pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
docker = "^7"
numpy = "^1.20"

[tool.isort]
profile = "black"
@ -18,6 +18,7 @@ serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
regex = "1.11.0"

[dev-dependencies]
float_eq = "1.0.1"
@ -1,9 +1,4 @@
use std::sync::LazyLock;

pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
    LazyLock::new(get_cuda_capability);

fn get_cuda_capability() -> Option<(usize, usize)> {
pub fn get_cuda_capability() -> Option<(usize, usize)> {
    use pyo3::prelude::*;

    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
@ -5,6 +5,7 @@ use hf_hub::{
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -66,7 +67,7 @@ fn get_config(
}

fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    let compute_capability = *gpu::COMPUTE_CAPABILITY;
    let compute_capability = gpu::get_cuda_capability();
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();
    if let Some(config) = config {
@ -300,6 +301,22 @@ impl std::fmt::Display for Dtype {
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
    #[clap(name = "fp8_e5m2")]
    Fp8e5m2,
}

impl std::fmt::Display for KVCacheDtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            KVCacheDtype::Fp8e5m2 => {
                write!(f, "fp8_e5m2")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
@ -401,6 +418,12 @@ struct Args {
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Specify the dtype for the key-value cache. When this option is not provided,
    /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
    /// the only supported value is `fp8_e5m2` on CUDA.
    #[clap(long, env, value_enum)]
    kv_cache_dtype: Option<KVCacheDtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
@ -669,6 +692,7 @@ fn shard_manager(
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    kv_cache_dtype: Option<KVCacheDtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
@ -742,6 +766,11 @@ fn shard_manager(
        shard_args.push(dtype.to_string())
    }

    if let Some(kv_cache_dtype) = kv_cache_dtype {
        shard_args.push("--kv-cache-dtype".to_string());
        shard_args.push(kv_cache_dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_args.push("--revision".to_string());
@ -915,6 +944,7 @@ fn shard_manager(
        }
    });
    // We read stdin in another thread as it seems that lines() can block in some cases
    if LevelFilter::current() >= tracing::Level::DEBUG {
        thread::spawn(move || {
            let mut stdin = io::stdin(); // We get `Stdin` here.
            loop {
@ -922,12 +952,11 @@ fn shard_manager(
                if let Ok(n) = stdin.read(&mut buffer) {
                    if n > 0 {
                        let _ = pstdin.write_all(&buffer[..n]);
                    } else {
                        break;
                    }
                }
            }
        });
    }

    let mut ready = false;
    let start_time = Instant::now();
@ -1302,6 +1331,7 @@ fn spawn_shards(
    let otlp_service_name = args.otlp_service_name.clone();
    let speculate = args.speculate;
    let dtype = args.dtype;
    let kv_cache_dtype = args.kv_cache_dtype;
    let trust_remote_code = args.trust_remote_code;
    let master_port = args.master_port;
    let disable_custom_kernels = args.disable_custom_kernels;
@ -1320,6 +1350,7 @@ fn spawn_shards(
            quantize,
            speculate,
            dtype,
            kv_cache_dtype,
            trust_remote_code,
            uds_path,
            rank,
@ -1812,14 +1843,37 @@ fn main() -> Result<(), LauncherError> {
            if adapter.contains('=') {
                continue;
            }

            let adapter = adapter.trim();

            // check if adapter has more than 1 '@'
            if adapter.matches('@').count() > 1 {
                return Err(LauncherError::ArgumentValidation(format!(
                    "Invalid LoRA adapter format: {}",
                    adapter
                )));
            }

            // capture adapter_id, path, revision in format of adapter_id=path@revision
            let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
            if let Some(caps) = re.captures(adapter) {
                let adapter_id = caps.get(1).map_or("", |m| m.as_str());
                let revision = caps.get(3).map(|m| m.as_str());

                download_convert_model(
                    adapter,
                    None,
                    adapter_id,
                    revision,
                    args.trust_remote_code,
                    args.huggingface_hub_cache.as_deref(),
                    args.weights_cache_override.as_deref(),
                    running.clone(),
                )?;
            } else {
                return Err(LauncherError::ArgumentValidation(format!(
                    "Invalid LoRA adapter format: {}",
                    adapter
                )));
            }
        }
    }

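The adapter pattern above accepts `id`, `id=path`, and `id=path@revision` forms; a quick Python check of the same regex (the pattern string is copied verbatim from the diff, the sample specs are illustrative):

    import re

    ADAPTER_RE = re.compile(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$")

    for spec in ["myorg/my-adapter", "my-adapter=/data/adapter", "my-adapter=/data/adapter@main"]:
        adapter_id, path, revision = ADAPTER_RE.match(spec).groups()
        print(spec, "->", adapter_id, path, revision)
    # myorg/my-adapter -> myorg/my-adapter None None
    # my-adapter=/data/adapter -> my-adapter /data/adapter None
    # my-adapter=/data/adapter@main -> my-adapter /data/adapter main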
23 nix/docker.nix Normal file
@ -0,0 +1,23 @@
{
  dockerTools,
  cacert,
  text-generation-inference,
  stream ? false,
}:

let
  build = if stream then dockerTools.streamLayeredImage else dockerTools.buildLayeredImage;
in
build {
  name = "tgi-docker";
  tag = "latest";
  config = {
    EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
    Env = [
      "HF_HOME=/data"
      "PORT=80"
    ];
  };
  contents = [ cacert ];
}
@ -1,5 +1,7 @@
{
  mkShell,
  black,
  isort,
  openssl,
  pkg-config,
  protobuf,
@ -14,6 +16,8 @@
mkShell {
  buildInputs =
    [
      black
      isort
      openssl.dev
      pkg-config
      (rust-bin.stable.latest.default.override {
41 nix/overlay.nix Normal file
@ -0,0 +1,41 @@
final: prev: {
  # You can use this overlay to temporarily override packages for
  # development. For permanent overrides, it's better to do this in
  # our package flake:
  #
  # https://github.com/huggingface/text-generation-inference-nix
  #
  # Note that overriding packages that are in the transitive closure
  # of many other packages (e.g. transformers) will require a large
  # rebuild.

  pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [
    (
      python-self: python-super: with python-self; {
        # Python package override example:
        # transformers = python-super.transformers.overrideAttrs (
        #   _: _: {
        #     src = final.fetchFromGitHub {
        #       owner = "huggingface";
        #       repo = "transformers";
        #       rev = "2bd4d5897dc73e8b172832070a6f9e567a0df017";
        #       hash = "sha256-JOIpKH9ssDEfI2Tf15e0iPKtThJwQ9GxMvRAnm+M2Pg=";
        #     };
        #   }
        # );
      }
    )
  ];

  # Non-python package override example:
  #
  # ripgrep = prev.ripgrep.overrideAttrs (
  #   _: _: {
  #     src = final.fetchFromGitHub {
  #       owner = "BurntSushi";
  #       repo = "ripgrep";
  #       rev = "79cbe89deb1151e703f4d91b19af9cdcc128b765";
  #       hash = "sha256-JPTM2KNmGMb+/jOfK3X7OM1wnN+3TU35SJOIcqmp3mg=";
  #     };
  #   });
}
@ -146,6 +146,7 @@ pub enum Config {
    ClipVisionModel(ClipVisionModel),
    Mistral,
    Idefics,
    Mllama,
    Idefics2(Idefics2),
    Ssm,
    GptBigcode,
@ -159,6 +160,7 @@ pub enum Config {
    #[serde(rename = "phi-msft")]
    PhiMsft,
    Phi3,
    PhiMoe,
    Llama,
    Baichuan,
    Paligemma(Paligemma),

@ -29,7 +29,7 @@ impl ChatTemplate {
        env.set_unknown_method_callback(pycompat::unknown_method_callback);
        let template_str = template.into_boxed_str();
        env.add_function("raise_exception", raise_exception);
        tracing::debug!("Loading template: {:#?}", template_str);
        tracing::debug!("Loading template: {}", template_str);

        // leaking env and template_str as read-only, static resources for performance.
        let template = Box::leak(env)

@ -355,6 +355,8 @@ pub enum InferError {
    MissingTemplateVariable(String),
    #[error("Tool error: {0}")]
    ToolError(String),
    #[error("Stream event serialization error")]
    StreamSerializationError(String),
}

impl InferError {
@ -368,6 +370,7 @@ impl InferError {
            InferError::TemplateError(_) => "template_error",
            InferError::MissingTemplateVariable(_) => "missing_template_variable",
            InferError::ToolError(_) => "tool_error",
            InferError::StreamSerializationError(_) => "stream_serialization_error",
        }
    }
}

@ -31,32 +31,29 @@ impl ToolGrammar {

        let mut tools = tools.clone();

        // add the notify_error function to the tools
        let notify_error = Tool {
        // add the no_tool function to the tools
        let no_tool = Tool {
            r#type: "function".to_string(),
            function: FunctionDefinition {
                name: "notify_error".to_string(),
                description: Some("Notify an error or issue".to_string()),
                name: "no_tool".to_string(),
                description: Some("Open ended response with no specific tool selected".to_string()),
                arguments: json!({
                    "type": "object",
                    "properties": {
                        "error": {
                        "content": {
                            "type": "string",
                            "description": "The error or issue to notify"
                            "description": "The response content",
                        }
                    },
                    "required": ["error"]
                    "required": ["content"]
                }),
            },
        };
        tools.push(notify_error);
        tools.push(no_tool);

        // if tools are provided and no tool_choice we default to the OneOf
        let tools_to_use = match tool_choice {
            ToolType::FunctionName(name) => {
                vec![Self::find_tool_by_name(&tools, &name)?]
            }
            ToolType::Function { function } => {
            ToolType::Function(function) => {
                vec![Self::find_tool_by_name(&tools, &function.name)?]
            }
            ToolType::OneOf => tools.clone(),

@ -957,12 +957,18 @@ pub fn default_tool_prompt() -> String {
}

#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
#[serde(untagged)]
#[schema(example = "auto")]
/// Controls which (if any) tool is called by the model.
pub enum ToolType {
    /// Means the model can pick between generating a message or calling one or more tools.
    #[schema(rename = "auto")]
    OneOf,
    FunctionName(String),
    Function { function: FunctionName },
    /// Means the model will not call any tool and instead generates a message.
    #[schema(rename = "none")]
    NoTool,
    /// Forces the model to call a specific tool.
    #[schema(rename = "function")]
    Function(FunctionName),
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
|
||||
@ -977,6 +983,7 @@ pub struct ToolChoice(pub Option<ToolType>);
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ToolTypeDeserializer {
|
||||
Null,
|
||||
String(String),
|
||||
ToolType(ToolType),
|
||||
}
|
||||
@ -984,10 +991,11 @@ enum ToolTypeDeserializer {
|
||||
impl From<ToolTypeDeserializer> for ToolChoice {
|
||||
fn from(value: ToolTypeDeserializer) -> Self {
|
||||
match value {
|
||||
ToolTypeDeserializer::Null => ToolChoice(None),
|
||||
ToolTypeDeserializer::String(s) => match s.as_str() {
|
||||
"none" => ToolChoice(Some(ToolType::NoTool)),
|
||||
"auto" => ToolChoice(Some(ToolType::OneOf)),
|
||||
_ => ToolChoice(Some(ToolType::FunctionName(s))),
|
||||
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
|
||||
},
|
||||
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
|
||||
}
|
||||
|
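A Python mirror of the conversion above, to show how client-side tool_choice payloads are interpreted (a sketch based on the match arms in this diff; the dict shapes are illustrative):

    def tool_choice_from_json(value):
        if value is None:
            return None                           # ToolChoice(None): no preference
        if value == "none":
            return {"NoTool": True}               # ToolType::NoTool: never call a tool
        if value == "auto":
            return {"OneOf": True}                # ToolType::OneOf: model may answer or call a tool
        if isinstance(value, str):
            return {"Function": {"name": value}}  # bare string forces that function
        return value                              # already a ToolType-shaped object

    assert tool_choice_from_json("auto") == {"OneOf": True}
    assert tool_choice_from_json("get_current_weather") == {"Function": {"name": "get_current_weather"}}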
@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
use regex::Regex;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
@ -452,12 +453,20 @@ async fn generate_stream(
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
    let span = tracing::Span::current();
    let on_message_callback = |stream_token: StreamResponse| {
        let event = Event::default();
        event.json_data(stream_token).unwrap()
    };
    let (headers, response_stream) =
        generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
        generate_stream_internal(infer, compute_type, Json(req), span).await;

    let response_stream = async_stream::stream! {
        let mut response_stream = Box::pin(response_stream);
        while let Some(raw_event) = response_stream.next().await {
            yield Ok(raw_event.map_or_else(Event::from, |token| {
                Event::default()
                    .json_data(token)
                    .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
            }));
        }
    };

    let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
    (headers, sse)
}
@ -466,9 +475,11 @@ async fn generate_stream_internal(
    infer: Infer,
    ComputeType(compute_type): ComputeType,
    Json(req): Json<GenerateRequest>,
    on_message_callback: impl Fn(StreamResponse) -> Event,
    span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
) -> (
    HeaderMap,
    impl Stream<Item = Result<StreamResponse, InferError>>,
) {
    let start_time = Instant::now();
    metrics::counter!("tgi_request_count").increment(1);

@ -500,12 +511,12 @@ async fn generate_stream_internal(
            let err = InferError::from(ValidationError::BestOfStream);
            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
            tracing::error!("{err}");
            yield Ok(Event::from(err));
            yield Err(err);
        } else if req.parameters.decoder_input_details {
            let err = InferError::from(ValidationError::PrefillDetailsStream);
            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
            tracing::error!("{err}");
            yield Ok(Event::from(err));
            yield Err(err);
        } else {
            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                // Keep permit as long as generate_stream lives
@ -535,8 +546,7 @@ async fn generate_stream_internal(
                                generated_text: None,
                                details: None,
                            };
                            let event = on_message_callback(stream_token);
                            yield Ok(event);
                            yield Ok(stream_token);
                        }
                        // Yield event for last token and compute timings
                        InferStreamResponse::End {
@ -600,9 +610,7 @@ async fn generate_stream_internal(
                                details
                            };

                            let event = on_message_callback(stream_token);
                            yield Ok(event);
                            yield Ok(stream_token);
                            break;
                        }
                    }
@ -610,7 +618,7 @@ async fn generate_stream_internal(
                    // yield error
                    Err(err) => {
                        error = true;
                        yield Ok(Event::from(err));
                        yield Err(err);
                        break;
                    }
                }
@ -619,7 +627,7 @@ async fn generate_stream_internal(
                // yield error
                Err(err) => {
                    error = true;
                    yield Ok(Event::from(err));
                    yield Err(err);
                }
            }
            // Check if generation reached the end
@ -628,7 +636,7 @@ async fn generate_stream_internal(
            let err = InferError::IncompleteGenerationStream;
            metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
            tracing::error!("{err}");
            yield Ok(Event::from(err));
            yield Err(err);
        }
    }
    };
@ -771,7 +779,24 @@ async fn completions(

        // Create a future for each generate_stream_internal call.
        let generate_future = async move {
            let on_message_callback = move |stream_token: StreamResponse| {
            let (header_tx, header_rx) = oneshot::channel();
            let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();

            tokio::spawn(async move {
                let (headers, response_stream) = generate_stream_internal(
                    infer_clone.clone(),
                    compute_type_clone.clone(),
                    Json(generate_request),
                    span_clone.clone(),
                )
                .await;

                let response_stream = async_stream::stream! {
                    let mut response_stream = Box::pin(response_stream);

                    while let Some(stream_token) = response_stream.next().await {
                        match stream_token {
                            Ok(stream_token) => {
                                let event = Event::default();

                                let current_time = std::time::SystemTime::now()
@ -817,29 +842,22 @@ async fn completions(
                                    }),
                                };

                                event
                                let event = event
                                    .json_data(message)
                                    .unwrap_or_else(|_e| Event::default())
                                    .unwrap_or_else(|_e| Event::default());

                                yield Ok(event);
                            }
                            Err(err) => yield Ok(Event::from(err)),
                        }
                    }
                };

            let (header_tx, header_rx) = oneshot::channel();
            let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();

            tokio::spawn(async move {
                let (header_map, sse) = generate_stream_internal(
                    infer_clone.clone(),
                    compute_type_clone.clone(),
                    Json(generate_request),
                    on_message_callback,
                    span_clone.clone(),
                )
                .await;

                // send and don't wait for response
                let _ = header_tx.send(header_map);
                let _ = header_tx.send(headers);

                // pin and emit messages to the sse_tx
                let mut sse = Box::pin(sse);
                let mut sse = Box::pin(response_stream);
                while let Some(event) = sse.next().await {
                    if sse_tx.send(event).is_err() {
                        tracing::error!("Failed to send event. Receiver dropped.");
@ -1072,6 +1090,84 @@ async fn completions(
    }
}

enum StreamState {
    Buffering,
    BufferTrailing,
    Content { skip_close_quote: bool },
}

/// Convert a StreamResponse into an Event to be sent over SSE
fn create_event_from_stream_token(
    stream_token: &StreamResponse,
    logprobs: bool,
    stream_options: Option<StreamOptions>,
    inner_using_tools: bool,
    system_fingerprint: String,
    model_id: String,
) -> Event {
    let event = Event::default();
    let current_time = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
        .as_secs();

    let logprobs = logprobs.then(|| {
        ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
    });

    // replace the content with the tool calls if grammar is present
    let (content, tool_calls) = if inner_using_tools {
        (None, Some(vec![stream_token.token.text.clone()]))
    } else {
        let content = if !stream_token.token.special {
            Some(stream_token.token.text.clone())
        } else {
            None
        };

        (content, None)
    };

    let (usage, finish_reason) = match &stream_token.details {
        Some(details) => {
            let usage = if stream_options
                .as_ref()
                .map(|s| s.include_usage)
                .unwrap_or(false)
            {
                let completion_tokens = details.generated_tokens;
                let prompt_tokens = details.input_length;
                let total_tokens = prompt_tokens + completion_tokens;
                Some(Usage {
                    completion_tokens,
                    prompt_tokens,
                    total_tokens,
                })
            } else {
                None
            };
            (usage, Some(details.finish_reason.format(true)))
        }
        None => (None, None),
    };

    let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
        model_id.clone(),
        system_fingerprint.clone(),
        content,
        tool_calls,
        current_time,
        logprobs,
        finish_reason,
        usage,
    ));

    event.json_data(chat_complete).unwrap_or_else(|e| {
        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
        Event::default()
    })
}

/// Generate tokens
#[utoipa::path(
    post,
@ -1128,88 +1224,135 @@ async fn chat_completions(
    // static values that will be returned in all cases
    let model_id = info.model_id.clone();
    let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));

    // switch on stream
    if stream {
        // pass this callback to the stream generation and build the required event structure
        let on_message_callback = move |stream_token: StreamResponse| {
            let event = Event::default();
        let (headers, response_stream) =
            generate_stream_internal(infer, compute_type, Json(generate_request), span).await;

        // regex to match any function name
        let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
            Ok(regex) => regex,
            Err(e) => {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: format!("Failed to compile regex: {}", e),
                        error_type: "regex".to_string(),
                    }),
                ))
            }
        };

        let response_stream = async_stream::stream! {
            let mut response_stream = Box::pin(response_stream);
            let mut buffer = Vec::new();
            let mut json_buffer = String::new();
            let mut state = if using_tools {
                StreamState::Buffering
            } else {
                StreamState::Content {
                    skip_close_quote: false,
                }
            };
            let mut response_as_tool = using_tools;
            while let Some(result) = response_stream.next().await {
                if let Ok(stream_token) = result {
                    let token_text = &stream_token.token.text.clone();
                    match state {
                        StreamState::Buffering => {
                            json_buffer.push_str(&token_text.replace(" ", ""));
                            buffer.push(stream_token);
                            if let Some(captures) = function_regex.captures(&json_buffer) {
                                let function_name = captures[1].to_string();
                                if function_name == "no_tool" {
                                    state = StreamState::BufferTrailing;
                                    response_as_tool = false;
                                    buffer.clear();
                                    json_buffer.clear();
                                } else {
                                    state = StreamState::Content {
                                        skip_close_quote: false,
                                    };
                                    // send all the buffered messages
                                    for stream_token in &buffer {
                                        let event = create_event_from_stream_token(
                                            stream_token,
                                            logprobs,
                                            stream_options.clone(),
                                            response_as_tool,
                                            system_fingerprint.clone(),
                                            model_id.clone(),
                                        );
                                        yield Ok::<Event, Infallible>(event);
                                    }
                                }
                            }
                        }
                        // if we skipped sending the buffer we need to avoid sending the following json key and quotes
                        StreamState::BufferTrailing => {
                            let infix_text = "\"content\":\"";
                            json_buffer.push_str(&token_text.replace(" ", ""));
                            // keep capturing until we find the infix text
                            match json_buffer.find(infix_text) {
                                Some(content_key_index) => {
                                    json_buffer =
                                        json_buffer[content_key_index + infix_text.len()..].to_string();
                                }
                                None => {
                                    continue;
                                }
                            }
                            // if there is leftover text after removing the infix text, we need to send it
                            if !json_buffer.is_empty() {
                                let event = Event::default();
                                let current_time = std::time::SystemTime::now()
                                    .duration_since(std::time::UNIX_EPOCH)
                                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                                    .as_secs();

            let logprobs = logprobs.then(|| {
                ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
            });

            // replace the content with the tool calls if grammar is present
            let (content, tool_calls) = if using_tools {
                (None, Some(vec![stream_token.token.text]))
            } else {
                let content = if !stream_token.token.special {
                    Some(stream_token.token.text)
                } else {
                    None
                };

                (content, None)
            };

            let (usage, finish_reason) = match stream_token.details {
                Some(details) => {
                    let usage = if stream_options
                        .as_ref()
                        .map(|s| s.include_usage)
                        .unwrap_or(false)
                    {
                        let completion_tokens = details.generated_tokens;
                        let prompt_tokens = details.input_length;
                        let total_tokens = prompt_tokens + completion_tokens;
                        Some(Usage {
                            completion_tokens,
                            prompt_tokens,
                            total_tokens,
                        })
                    } else {
                        None
                    };
                    (usage, Some(details.finish_reason.format(true)))
                }
                None => (None, None),
            };
            event
                .json_data(CompletionType::ChatCompletionChunk(
                    ChatCompletionChunk::new(
                                let chat_complete =
                                    CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
                        model_id.clone(),
                        system_fingerprint.clone(),
                        content,
                        tool_calls,
                                        Some(json_buffer.clone()),
                                        None,
                        current_time,
                        logprobs,
                        finish_reason,
                        usage,
                    ),
                ))
                .unwrap_or_else(|e| {
                    println!("Failed to serialize ChatCompletionChunk: {:?}", e);
                    Event::default()
                })
        };

        let (headers, response_stream) = generate_stream_internal(
            infer,
            compute_type,
            Json(generate_request),
            on_message_callback,
            span,
        )
        .await;

        let response_stream = response_stream.chain(futures::stream::once(async {
            Ok(Event::default().data("[DONE]"))
                                        None,
                                        None,
                                        None,
        ));
                                yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
                                    InferError::StreamSerializationError(e.to_string()).into()
                                }));
                            }
                            // cleanup the buffers
                            buffer.clear();
                            json_buffer.clear();
                            state = StreamState::Content {
                                skip_close_quote: true,
                            };
                        }
                        StreamState::Content { skip_close_quote } => {
                            if skip_close_quote && token_text.contains('"') {
                                break;
                            }

                            // send the content
                            let event = create_event_from_stream_token(
                                &stream_token,
                                logprobs,
                                stream_options.clone(),
                                response_as_tool,
                                system_fingerprint.clone(),
                                model_id.clone(),
                            );

                            yield Ok::<Event, Infallible>(event);
                        }
                    }
                }
            }
            yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
        };

        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
        Ok((headers, sse).into_response())
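The buffering state machine above keys off the first function name that appears in the accumulated JSON; the same regex behaves like this in Python (the pattern is copied from the diff, and the sample prefixes already have spaces stripped, matching the token_text.replace(" ", "") step):

    import re

    FUNCTION_RE = re.compile(r'\{"function":\{"_name":"([^"]+)"')

    for prefix in ['{"function":{"_name":"no_tool"', '{"function":{"_name":"get_current_weather"']:
        name = FUNCTION_RE.search(prefix).group(1)
        # "no_tool" -> switch to plain content streaming; any other name -> stream tool-call deltas
        print(name)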
@ -1246,7 +1389,21 @@ async fn chat_completions(
            if let Value::Object(ref mut props) = arguments {
                props.remove("_name");
            }

            match name.as_str() {
                "no_tool" => {
                    // parse the content message
                    let content_message = arguments
                        .get("content")
                        .and_then(Value::as_str)
                        .ok_or_else(|| {
                            InferError::ToolError(
                                "No `content` found in generated text".to_string(),
                            )
                        })?
                        .to_string();
                    (None, Some(content_message))
                }
                _ => {
                    let tool_calls = vec![ToolCall {
                        id: "0".to_string(),
                        r#type: "function".to_string(),
@ -1257,6 +1414,8 @@ async fn chat_completions(
                        },
                    }];
                    (Some(tool_calls), None)
                }
            }
        } else {
            (None, Some(generation.generated_text))
        };
@ -1937,6 +2096,11 @@ async fn start(
        metrics::Unit::Count,
        "Maximum tokens for the current batch"
    );
    metrics::describe_gauge!(
        "tgi_batch_total_tokens",
        metrics::Unit::Count,
        "Maximum amount of tokens in total."
    );
    metrics::describe_histogram!(
        "tgi_request_max_new_tokens",
        metrics::Unit::Count,
@ -2318,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
        InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
        InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
        InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
        InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
    };

    (
@ -2495,8 +2660,8 @@ mod tests {
@ -2495,8 +2660,8 @@ mod tests {
        );

        assert!(result.is_ok());
        let (inputs, _grammar, using_tools) = result.unwrap();
        let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
        assert_eq!(using_tools, true);
        assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
        assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
    }
}
@ -567,6 +567,7 @@ fn image_tokens(
    use HubPreprocessorConfig::*;
    match config {
        Idefics => "<image>".to_string(),
        Mllama => "<|image|>".to_string(),
        Idefics2(config) => {
            const FAKE: &str = "<fake_token_around_image>";
            const IMAGE: &str = "<image>";
@ -618,7 +619,7 @@ fn prepare_input(
    use Config::*;
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    let (tokenizer_query, input_chunks) = match config {
        Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
@ -1,5 +1,5 @@
|
||||
[toolchain]
|
||||
# Released on: June 13, 2024
|
||||
# https://releases.rs/docs/1.79.0/
|
||||
channel = "1.80.0"
|
||||
channel = "1.80.1"
|
||||
components = ["rustfmt", "clippy"]
|
||||
@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4

build-flash-attention-v2-cuda:
    pip install -U packaging wheel
@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
build-flash-attention-v2-rocm:
    if [ ! -d 'flash-attention-v2' ]; then \
        pip install -U packaging ninja --no-cache-dir && \
        git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
        git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
        cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
        git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
    fi

@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
    if [ ! -d 'vllm' ]; then \
        pip install -U ninja packaging --no-cache-dir && \
@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm:
    if [ ! -d 'vllm' ]; then \
        pip install -U ninja packaging --no-cache-dir && \
        git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
        git clone https://github.com/mht-sharma/vllm.git vllm; \
    fi
    cd vllm && git fetch && git checkout $(commit_rocm) && \
        PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

@ -1,5 +1,17 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

extra_cuda_cflags = []
extra_cflags = []
if torch.version.hip:
    extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
    extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]

extra_compile_args = {
    "cxx": extra_cflags,
    "nvcc": extra_cuda_cflags,
}

setup(
    name="exllama_kernels",
@ -13,6 +25,7 @@ setup(
            "exllama_kernels/cuda_func/q4_matmul.cu",
            "exllama_kernels/cuda_func/q4_matrix.cu",
        ],
        extra_compile_args=extra_compile_args,
        )
    ],
    cmdclass={"build_ext": BuildExtension},

@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

extra_cuda_cflags = ["-lineinfo", "-O3"]

extra_cflags = []
if torch.version.hip:
    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
    extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]

extra_compile_args = {
    "cxx": extra_cflags,
    "nvcc": extra_cuda_cflags,
}

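For reference, a minimal self-contained sketch of the flag-selection pattern both setup scripts above use (flag values copied from the hunks; on ROCm builds `torch.version.hip` is a version string, on CUDA it is None):

import torch

extra_cflags = []
extra_cuda_cflags = ["-lineinfo", "-O3"]
if torch.version.hip:
    # Legacy hipBLAS interface is requested for both host and device compilation.
    extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]

extra_compile_args = {"cxx": extra_cflags, "nvcc": extra_cuda_cflags}
print(extra_compile_args)
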
2583 server/poetry.lock (generated; file diff suppressed because it is too large)

@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation-grpc = "^0.46b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
sentencepiece = "^0.2"
tokenizers = "^0.20"
huggingface-hub = "^0.23"
transformers = "^4.43"
transformers = "^4.45"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
@ -47,10 +47,10 @@ marlin-kernels = [
    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.7.1"

@ -82,3 +82,6 @@ requires = [
    "poetry-core>=1.0.0",
]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"

@ -1,19 +1,19 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

(The identical dependency updates above are repeated verbatim for the two remaining hardware-specific requirements files; the duplicate diffs are omitted.)

@ -30,6 +30,10 @@ class Dtype(str, Enum):
    bloat16 = "bfloat16"


class KVCacheDtype(str, Enum):
    fp8_e5m2 = "fp8_e5m2"


@app.command()
def serve(
    model_id: str,
@ -38,6 +42,7 @@ def serve(
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    kv_cache_dtype: Optional[KVCacheDtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
@ -97,6 +102,7 @@ def serve(
    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
@ -114,6 +120,7 @@ def serve(
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,

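A small self-contained sketch of the enum-to-string downgrade used in `serve` above (only the new `KVCacheDtype` member is shown; the pattern is identical for `quantize` and `dtype`):

from enum import Enum
from typing import Optional

class KVCacheDtype(str, Enum):
    fp8_e5m2 = "fp8_e5m2"

def downgrade(value: Optional[KVCacheDtype]) -> Optional[str]:
    # None stays None; an enum member becomes its plain string value.
    return None if value is None else value.value

assert downgrade(None) is None
assert downgrade(KVCacheDtype.fp8_e5m2) == "fp8_e5m2"
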
@ -1,37 +1,40 @@
from text_generation_server.utils.import_utils import SYSTEM
import os

from text_generation_server.utils.import_utils import SYSTEM

from .common import Seqlen

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
    from .cuda import (
        PREFILL_IN_KV_CACHE,
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
        PREFILL_IN_KV_CACHE,
    )
elif SYSTEM == "rocm":
    from .rocm import (
        PREFILL_IN_KV_CACHE,
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
        PREFILL_IN_KV_CACHE,
        SUPPORTS_WINDOWING,
    )
elif SYSTEM == "ipex":
    from .ipex import (
        PREFILL_IN_KV_CACHE,
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
        PREFILL_IN_KV_CACHE,
        SUPPORTS_WINDOWING,
    )
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache

__all__ = [
    "attention",
@ -39,5 +42,6 @@ __all__ = [
    "reshape_and_cache",
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "Seqlen",
]

@ -1,4 +1,5 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@ -65,5 +66,7 @@ else:
    max_k: int

    def clamp(self, max):
        raise NotImplementedError("Not implemented seqlen for paged")
        return Seqlen(torch.clamp(self.input_lengths, max=max))
        if SYSTEM == "rocm":
            return self
        self.input_lengths = torch.clamp(self.input_lengths, max=max)
        return self

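A hedged sketch of the new `clamp` semantics above: instead of returning a fresh `Seqlen`, the lengths tensor is clamped in place, except on ROCm where the lengths are left untouched (the standalone function below is illustrative, not the repository's API):

import torch

def clamp_input_lengths(input_lengths: torch.Tensor, max_len: int, system: str) -> torch.Tensor:
    # ROCm keeps the original lengths; every other system clamps them.
    if system == "rocm":
        return input_lengths
    return torch.clamp(input_lengths, max=max_len)

print(clamp_input_lengths(torch.tensor([3, 17, 9]), 8, "cuda"))  # tensor([3, 8, 8])
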
@ -355,3 +355,11 @@ else:
    # have a configuration that requires flash-attention v1, which
    # does not support block tables.
    PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

__all__ = [
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]

@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
    dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer prefill state to the given
@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
            q_data_type=dtype,
            page_size=page_size,
            window_left=window_left,
        )
        yield
    finally:
@ -119,7 +121,8 @@ def use_prefill_state(
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    query_dtype: str = "float16",
    dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer prefill state to the given
@ -135,7 +138,8 @@ def use_prefill_state(
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
            q_data_type=dtype,
            window_left=window_left,
        )
        yield
    finally:
@ -200,7 +204,8 @@ def use_decode_state(
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
    dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer decoding state to the given
@ -235,7 +240,9 @@ def use_decode_state(
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            q_data_type=query_dtype,
            data_type=dtype,
            q_data_type=dtype,
            window_left=window_left,
        )
        yield
    finally:

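The hunks above replace the stringly-typed `query_dtype: str = "float16"` parameter with a required `dtype: torch.dtype`. A small sketch of what that means at a call site (the `coerce` helper is hypothetical, shown only to illustrate the two forms):

import torch

# before: plan(..., q_data_type="float16")
# after:  plan(..., q_data_type=torch.float16)
def coerce(dtype) -> torch.dtype:
    # Accept either form while call sites migrate.
    return getattr(torch, dtype) if isinstance(dtype, str) else dtype

assert coerce("float16") is torch.float16
assert coerce(torch.bfloat16) is torch.bfloat16
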
@ -80,3 +80,12 @@ def paged_attention(
        None,
    )
    return out


__all__ = [
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]

119 server/text_generation_server/layers/attention/kv_cache.py (new file)
@ -0,0 +1,119 @@
from typing import Tuple

import torch
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache


class KVCache:
    """
    Key-value cache for attention layers.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""

        if dtype == torch.float8_e5m2 and (
            ATTENTION != "flashinfer" or SYSTEM != "cuda"
        ):
            raise ValueError(
                "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
            )

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "ipex" and device.type == "xpu":
            x = 1
        else:
            x = BLOCK_SIZE // element_size

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            self.kv_cache = (
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        else:
            self.kv_cache = (
                torch.zeros(
                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros(
                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
                    dtype=dtype,
                    device=device,
                ),
            )

    @property
    def key(self):
        """Get the key cache."""

        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""

        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
    ):
        """Store the key and value at the given slots."""

        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            # TODO: add scale
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype == torch.float8_e5m2:
                # Torch index_put does not support float8_e5m2 yet, so
                # put as raw data instead.
                key_cache = key_cache.view(torch.uint8)
                value_cache = value_cache.view(torch.uint8)
                key = key.view(torch.uint8)
                value = value.view(torch.uint8)
            shape = key_cache.shape
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
        else:
            reshape_and_cache(key, value, key_cache, value_cache, slots)

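A minimal usage sketch for the new `KVCache` above, assuming ATTENTION == "flashdecoding" (where keys/values are stored as `(num_tokens, num_heads, head_size)`); the sizes are illustrative:

import torch

cache = KVCache(
    num_blocks=8,
    num_heads=4,
    head_size=64,
    dtype=torch.float16,
    device=torch.device("cuda"),
)
# Write one token's key/value into flat slot 0 of the paged cache.
key = torch.randn(1, 4, 64, dtype=torch.float16, device="cuda")
value = torch.randn(1, 4, 64, dtype=torch.float16, device="cuda")
cache.store(key=key, value=value, slots=torch.tensor([0], device="cuda"))
assert torch.equal(cache.key.view(-1, 4, 64)[0], key[0])
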
@ -1,4 +1,5 @@
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
@ -8,16 +9,28 @@ from loguru import logger

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512

_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256

use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

PREFILL_IN_KV_CACHE = False

use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
    from vllm._C import cache_ops
    if use_rocm_custom_paged_attn:
        from vllm._custom_C import paged_attention_custom
except ImportError as e:
    log_master(
        logger.info,
        f"Custom Paged Attention not available. Complete error: {e}",
    )
    use_rocm_custom_paged_attn = False

try:
    import vllm._custom_ops as ops
except Exception as e:
    raise ImportError(
        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@ -36,9 +49,7 @@ def reshape_and_cache(
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
    else:
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def paged_attention(
@ -48,8 +59,9 @@ def paged_attention(
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: Seqlen,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
@ -68,11 +80,31 @@ def paged_attention(
    # limitations under the License.
    #

    if softcap is not None:
        raise RuntimeError("Paged attention doesn't support softcapping")

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape

    num_kv_heads = key_cache.shape[1]
    gqa_ratio = num_heads // num_kv_heads
    use_custom = (
        use_rocm_custom_paged_attn
        and (query.dtype == torch.half or query.dtype == torch.bfloat16)
        and (head_size == 128 or head_size == 64)
        and (block_size == 16 or block_size == 32)
        and (gqa_ratio >= 1 and gqa_ratio <= 16)
        and max_s <= 32768
    )

    if not use_custom:
        _PARTITION_SIZE = _PARTITION_SIZE_V1V2
    else:
        _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    input_lengths = input_lengths.input_lengths
    input_lengths = seqlen.input_lengths

    out = torch.empty_like(query)

@ -81,9 +113,13 @@ def paged_attention(
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    from vllm._C import ops
    import vllm._custom_ops as ops

    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
    use_v1 = (
        max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        and not use_custom
    )
    if use_v1:
        ops.paged_attention_v1(
            out,
@ -115,6 +151,7 @@ def paged_attention(
        )
        max_logits = torch.empty_like(exp_sums)

        if not use_custom:
            ops.paged_attention_v2(
                out,
                exp_sums,
@ -133,6 +170,25 @@ def paged_attention(
                "auto",
                1.0,
            )
        else:
            paged_attention_custom(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
            )

    return out

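A hedged restatement of the dispatch rule added above: the custom ROCm paged-attention kernel only handles fp16/bf16 queries, head sizes 64/128, block sizes 16/32, GQA ratios 1-16, and sequence lengths up to 32k; everything else falls back to vLLM's V1/V2 kernels.

import torch

def can_use_custom_paged_attn(dtype, head_size, block_size, gqa_ratio, max_s, enabled=True):
    # Mirrors the `use_custom` expression in the hunk above.
    return (
        enabled
        and dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )

assert can_use_custom_paged_attn(torch.half, 128, 16, 8, 4096)
assert not can_use_custom_paged_attn(torch.float32, 128, 16, 8, 4096)
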
@ -175,13 +231,14 @@ if ENGINE == "ck":

    def attention(
        q,
        k,
        v,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        seqlen: Seqlen,
        block_tables: torch.Tensor,
        softmax_scale: float,
        window_size_left: int = -1,
        causal: bool = True,
        softcap: float = 0.0,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
@ -191,46 +248,57 @@ if ENGINE == "ck":
        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            key_cache,
            value_cache,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            None,
            None,
            None,
            None,
            seqlen.max_q,
            seqlen.max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            window_size_left,
            0,
            softcap,
            False,
            None,
        )
        )[0]

elif ENGINE == "triton":
    from .flash_attn_triton import triton_attention

    def attention(
        q,
        k,
        v,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        seqlen: Seqlen,
        block_tables: torch.Tensor,
        softmax_scale: float,
        window_size_left: int = -1,
        causal: bool = True,
        softcap: Optional[float] = None,
    ):
        if softcap is not None:
            raise NotImplementedError("softcap is only available with CK flash attn")

        out = torch.empty_like(q)

        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
        output, _ = triton_attention(
            q,
            k,
            v,
            key_cache,
            value_cache,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_k,
            causal,
            softmax_scale,
        )
@ -238,3 +306,11 @@ elif ENGINE == "triton":

else:
    raise RuntimeError(f"Unknown attention engine {ENGINE}")

__all__ = [
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]

@ -1,12 +1,21 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os

if SYSTEM == "rocm":
    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
        "true",
        "1",
    )

    if ROCM_USE_SKINNY_GEMM:
        try:
            from vllm import _custom_C
        except Exception as e:
            raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
            raise ImportError(
                f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
            )


class FastLinear(torch.nn.Module):
@ -48,6 +57,14 @@ class FastLinearROCm(torch.nn.Module):
        else:
            self.bias = None

        self.cu_count = torch.cuda.get_device_properties(
            device="cuda"
        ).multi_processor_count
        self.use_skinny_gemm = (
            ROCM_USE_SKINNY_GEMM
            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
@ -61,7 +78,11 @@ class FastLinearROCm(torch.nn.Module):
        weight = self.weight
        bias = self.bias

        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
        if (
            self.use_skinny_gemm
            and inp.dtype == torch.float16
            and inp.shape[-1] % 8 == 0
        ):
            batched = False
            inp_shape = inp.shape

@ -69,13 +90,16 @@ class FastLinearROCm(torch.nn.Module):
                inp = inp.view(-1, inp_shape[-1])
                batched = True

            m, k = weight.shape[0], inp_shape[1]
            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
            if m > 8 and n <= 4:
                out = torch.empty(
                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
                )
                _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
            elif m % 4 == 0 and n == 1 and k <= 8192:
                out = torch.empty(
                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
                )
                if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
                    _custom_C.LLMM1(weight, inp, out, 8)
                elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
                    _custom_C.LLMM1(weight, inp, out, 4)
            else:
                out = F.linear(inp, weight)

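A hedged summary of the skinny-gemm dispatch above (the helper below is invented for illustration): `wvSpltK` covers tall-skinny shapes, `LLMM1` covers single-row inputs with aligned small K, and everything else falls back to `F.linear`.

def pick_rocm_gemm(m: int, n: int, k: int) -> str:
    # Mirrors the branch order in FastLinearROCm.forward above.
    if m > 8 and n <= 4:
        return "wvSpltK"
    if m % 4 == 0 and n == 1 and k <= 8192:
        return "LLMM1"
    return "F.linear"

assert pick_rocm_gemm(7168, 1, 8192) == "wvSpltK"
assert pick_rocm_gemm(4, 1, 4096) == "LLMM1"
assert pick_rocm_gemm(9, 8, 4096) == "F.linear"
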
@ -43,7 +43,7 @@ def can_use_gptq_marlin(
        and quant_method in {"awq", "gptq"}
        and bits in GPTQ_MARLIN_BITS
        and groupsize in GPTQ_MARLIN_GROUP_SIZES
        # We only suppord asymmetric quantization for AWQ.
        # We only support asymmetric quantization for AWQ.
        and (sym or quant_method == "awq")
    )

@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):

        try:
            qweight = weights.get_packed_sharded(
                f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@ -352,7 +351,7 @@ def repack_gptq_for_marlin(

    scales = permute_scales(scales)

    is_full_k = not (desc_act and sharded_infeatures)
    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

    return GPTQMarlinWeight(
        qweight=repacked,

@ -10,16 +10,24 @@ from text_generation_server.layers import (
    TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
from text_generation_server.layers.moe.gptq_marlin import (
    GPTQMarlinSparseMoELayer,
    can_use_marlin_moe_gemm,
)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    Weights,
    UnquantizedWeight,
)

if SYSTEM != "ipex":
if SYSTEM == "rocm":
    from .fused_moe_rocm import grouped_topk
    from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
    from moe_kernels.fused_moe import fused_topk, grouped_topk


@ -202,12 +210,22 @@ class SparseMoELayer(nn.Module):
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
            cls = UnquantizedSparseMoELayer
        # Once we wire up GPTQ-Marlin MoE:
        # elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
        #     cls = GPTQMarlinSparseMoELayer
        elif isinstance(
            weights.loader, GPTQMarlinWeightsLoader
        ) and can_use_marlin_moe_gemm(
            quant_method=weights.loader.quant_method,
            quantize=weights.loader.quantize,
            sym=weights.loader.sym,
        ):
            cls = GPTQMarlinSparseMoELayer
        else:
            raise ValueError(
                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
                f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
            )

        log_once(
            logger.info,
            "Using MoE layer wih fused gemm",
        )

        self.moe = cls(
@ -234,6 +252,12 @@ class SparseMoELayer(nn.Module):
                and isinstance(weights.loader.weight_class, UnquantizedWeight)
            )
            or isinstance(weights.loader, HybridFP8UnquantLoader)
            # Once we wire up GPTQ-Marlin MoE:
            # or isinstance(weights.loader, GPTQMarlinWeightsLoader)
            or (
                isinstance(weights.loader, GPTQMarlinWeightsLoader)
                and can_use_marlin_moe_gemm(
                    quant_method=weights.loader.quant_method,
                    quantize=weights.loader.quantize,
                    sym=weights.loader.sym,
                )
            )
        )

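A hedged restatement of the layer selection above (loader kinds reduced to plain strings for illustration): unquantized or hybrid-FP8 weights route to the unquantized MoE layer, GPTQ-Marlin weights route to the fused Marlin MoE layer when the GEMM gate passes, and anything else raises.

def select_moe_layer(loader: str, marlin_gemm_ok: bool) -> str:
    if loader in ("unquantized", "hybrid_fp8"):
        return "UnquantizedSparseMoELayer"
    if loader == "gptq_marlin" and marlin_gemm_ok:
        return "GPTQMarlinSparseMoELayer"
    raise ValueError("sparse MoE supports unquantized, AWQ, and GPTQ weights only")

assert select_moe_layer("gptq_marlin", True) == "GPTQMarlinSparseMoELayer"
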
52 server/text_generation_server/layers/moe/fused_moe_rocm.py (new file)
@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
import torch.distributed


# TODO: Remove the functions once moe_kernel are built for ROCM
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids

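A quick usage sketch for `grouped_topk` above (token and expert counts invented): 8 experts in 4 groups, keep the 2 best groups, then the 2 best experts overall.

import torch

hidden = torch.randn(2, 16)   # not used by the routing math itself
gating = torch.randn(2, 8)    # 2 tokens, 8 experts
weights, ids = grouped_topk(
    hidden, gating, topk=2, renormalize=True, num_expert_group=4, topk_group=2
)
print(weights.shape, ids.shape)  # torch.Size([2, 2]) torch.Size([2, 2])
# With renormalize=True the kept weights sum to 1 per token.
assert torch.allclose(weights.sum(dim=-1), torch.ones(2))
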
228 server/text_generation_server/layers/moe/gptq_marlin.py (new file)
@ -0,0 +1,228 @@
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
    GPTQMarlinWeight,
    GPTQMarlinWeightsLoader,
)

if SYSTEM == "cuda":
    from moe_kernels.fused_marlin_moe import fused_marlin_moe
else:
    fused_marlin_moe = None


try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False


def can_use_marlin_moe_gemm(
    *,
    quant_method: str,
    quantize: str,
    sym: bool,
):
    return (
        SYSTEM == "cuda"
        and fused_marlin_moe is not None
        and has_sm_8_0
        and quantize in {"awq", "gptq"}
        and quant_method in {"awq", "gptq"}
        # We only support asymmetric quantization for AWQ.
        and (sym or quant_method == "awq")
    )


@dataclass
class GPTQMarlinMoEWeight:
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
    is_full_k: bool


class GPTQMarlinSparseMoELayer(nn.Module):
    """
    MoE layer that uses a fused GPTQ-Marlin kernel.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        if not (
            isinstance(weights.loader, GPTQMarlinWeightsLoader)
            and can_use_marlin_moe_gemm(
                quant_method=weights.loader.quant_method,
                quantize=weights.loader.quantize,
                sym=weights.loader.sym,
            )
        ):
            raise ValueError(
                f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported"
            )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize

        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            names=[gate_proj_name, up_proj_name],
            weights=weights,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
        )

        self.bits = weights.loader.bits

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return fused_marlin_moe(
            hidden_states=x,
            w1=self.gate_up_proj.qweight,
            w2=self.down_proj.qweight,
            w1_scale=self.gate_up_proj.scales,
            w2_scale=self.down_proj.scales,
            w1_zeros=(
                self.gate_up_proj.qzeros
                if self.gate_up_proj.qzeros.numel() > 0
                else None
            ),
            w2_zeros=(
                self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None
            ),
            g_idx1=self.gate_up_proj.g_idx,
            g_idx2=self.down_proj.g_idx,
            sort_indices1=self.gate_up_proj.perm,
            sort_indices2=self.down_proj.perm,
            is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            num_bits=self.bits,
        )


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    names: List[str],
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    moe_weight = None
    for i in range(n_experts):
        weight = weights.get_multi_weights_col(
            [f"{prefix}.{i}.{name}" for name in names], 0
        )
        assert isinstance(weight, GPTQMarlinWeight)
        moe_weight = _pack_weight(
            n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
        )
    assert moe_weight is not None
    return moe_weight


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    moe_weight = None
    for i in range(n_experts):
        weight = weights.get_weights_row(
            f"{prefix}.{i}.{name}",
        )
        assert isinstance(weight, GPTQMarlinWeight)
        moe_weight = _pack_weight(
            n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
        )
    assert moe_weight is not None
    return moe_weight


def _pack_weight(
    *,
    n_experts: int,
    expert: int,
    moe_weight: Optional[GPTQMarlinMoEWeight],
    weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
    if moe_weight is None:
        qweight = torch.empty(
            (n_experts,) + weight.qweight.shape,
            dtype=weight.qweight.dtype,
            device=weight.qweight.device,
        )
        qzeros = torch.empty(
            (n_experts,) + weight.qzeros.shape,
            dtype=weight.qzeros.dtype,
            device=weight.qzeros.device,
        )
        scales = torch.empty(
            (n_experts,) + weight.scales.shape,
            dtype=weight.scales.dtype,
            device=weight.scales.device,
        )
        g_idx = torch.empty(
            (n_experts,) + weight.g_idx.shape,
            dtype=weight.g_idx.dtype,
            device=weight.g_idx.device,
        )
        perm = torch.empty(
            (n_experts,) + weight.perm.shape,
            dtype=weight.perm.dtype,
            device=weight.perm.device,
        )

        moe_weight = GPTQMarlinMoEWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            perm=perm,
            is_full_k=weight.is_full_k,
        )

    moe_weight.qweight[expert] = weight.qweight
    moe_weight.qzeros[expert] = weight.qzeros
    moe_weight.scales[expert] = weight.scales
    moe_weight.g_idx[expert] = weight.g_idx
    moe_weight.perm[expert] = weight.perm

    return moe_weight

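In short, the `can_use_marlin_moe_gemm` gate defined above requires CUDA with the fused kernel importable, SM >= 8.0, GPTQ/AWQ quantization, and (for GPTQ) symmetric quantization. For example:

# On a CUDA SM >= 8.0 machine with moe-kernels installed this returns True:
ok = can_use_marlin_moe_gemm(quant_method="gptq", quantize="gptq", sym=True)
# Asymmetric GPTQ is rejected regardless of hardware:
assert not can_use_marlin_moe_gemm(quant_method="gptq", quantize="gptq", sym=False)
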
@ -6,7 +6,9 @@ import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights

if SYSTEM != "ipex":
if SYSTEM == "rocm":
    from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
    from moe_kernels.fused_moe import fused_moe


@ -52,6 +54,17 @@ class UnquantizedSparseMoELayer(nn.Module):
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        if SYSTEM == "rocm":
            return fused_moe(
                x,
                self.gate_up_proj,
                self.down_proj,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                inplace=True,
            )

        return fused_moe(
            x,
            w1=self.gate_up_proj,

@ -166,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
                    1 + math.log(scale) / math.log(original_max_position_embeddings)
                )

            # If short_mscale and long_mscale are provided, we need to scale
            # the freqs using the Phi3LongRoPEScaledRotaryEmbedding.
            if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
                short_mscale = rope_scaling["short_mscale"]
                long_mscale = rope_scaling["long_mscale"]
                return Phi3LongRoPEScaledRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    max_position_embeddings=config.max_position_embeddings,
                    short_mscale=short_mscale,
                    long_mscale=long_mscale,
                    original_max_position_embeddings=original_max_position_embeddings,
                )

            return SuRotaryEmbedding(
                short_inv_freq=short_inv_freq,
                long_inv_freq=long_inv_freq,
@ -287,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
@ -308,6 +323,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)


class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq: torch.Tensor,
        long_inv_freq: torch.Tensor,
        max_position_embeddings: int,
        short_mscale: float,
        long_mscale: float,
        original_max_position_embeddings: int,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.max_position_embeddings = max_position_embeddings
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale
        self.original_max_position_embeddings = original_max_position_embeddings

        # cache
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)

            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )

            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            short_freqs = short_freqs * self.short_mscale
            long_freqs = long_freqs * self.long_mscale

            freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
            freqs[: self.original_max_position_embeddings] = short_freqs
            freqs[self.original_max_position_embeddings :] = long_freqs

            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

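To make the split-frequency cache above concrete, a small hedged demo of the layout: positions below original_max_position_embeddings use the short-context inverse frequencies, later positions use the long-context ones, each scaled by its mscale. Dimensions and values are illustrative only.

import torch

original_max = 4                            # original_max_position_embeddings
seqlen = 8                                  # total positions to cache
short_inv_freq = torch.tensor([1.0, 0.1])
long_inv_freq = torch.tensor([0.5, 0.05])   # rescaled bank for long contexts
short_mscale, long_mscale = 1.0, 1.2

t = torch.arange(seqlen, dtype=short_inv_freq.dtype)
freqs = torch.empty((seqlen, short_inv_freq.shape[0]))
# First `original_max` rows come from the short bank, the rest from the long bank.
freqs[:original_max] = torch.outer(t[:original_max], short_inv_freq) * short_mscale
freqs[original_max:] = torch.outer(t[original_max:], long_inv_freq) * long_mscale
cos, sin = torch.cos(freqs), torch.sin(freqs)
print(cos.shape)  # torch.Size([8, 2])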
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
@ -467,7 +539,6 @@ def apply_llama3_scaling(
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scaling_factor)
        else:

            assert low_freq_wavelen != high_freq_wavelen
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
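The hunk above only shows the middle of apply_llama3_scaling; a hedged, self-contained sketch of the full piecewise rule it belongs to follows (high-frequency components untouched, low-frequency ones divided by the scaling factor, a smooth blend in between). Default factor values are assumptions for illustration.

import math
import torch

def apply_llama3_scaling_sketch(
    freqs,
    scaling_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
):
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)                   # keep high frequencies
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scaling_factor)  # slow down low frequencies
        else:
            # Smoothly interpolate between the two regimes.
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
print(apply_llama3_scaling_sketch(inv_freq))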
@ -32,6 +32,9 @@ from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
@ -73,6 +76,7 @@ FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
@ -109,7 +113,11 @@ try:
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    )
@ -146,7 +154,7 @@ except ImportError as e:

if FLASH_ATTENTION:
    __all__.append(FlashCausalLM)
    __all__.append(IDEFICSSharded)
    __all__.append(IdeficsCausalLM)

MAMBA_AVAILABLE = True
try:
@ -237,6 +245,11 @@ class ModelType(enum.Enum):
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
@ -308,6 +321,12 @@ class ModelType(enum.Enum):
        "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
        "multimodal": True,
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }


__GLOBALS = locals()
@ -323,6 +342,7 @@ def get_model(
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
@ -384,6 +404,13 @@ def get_model(
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if kv_cache_dtype is None:
        kv_cache_dtype = dtype
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

    if speculate is not None:
        set_speculate(speculate)
    else:
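The kv_cache_dtype resolution above is small enough to read inline, but pulled out as a standalone helper it is easy to test; this hedged sketch mirrors the hunk exactly, and the single "fp8_e5m2" option is the only name the diff confirms.

from typing import Optional, Union
import torch

def resolve_kv_cache_dtype(
    kv_cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: torch.dtype,
) -> torch.dtype:
    if kv_cache_dtype is None:
        # Default: store the KV cache in the same dtype as the model weights.
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        # 8-bit float (5 exponent / 2 mantissa bits) halves KV-cache memory.
        return torch.float8_e5m2
    raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

print(resolve_kv_cache_dtype(None, torch.float16))        # torch.float16
print(resolve_kv_cache_dtype("fp8_e5m2", torch.float16))  # torch.float8_e5m2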
@ -544,6 +571,7 @@ def get_model(
                speculator=speculator,
                default_dtype=torch.bfloat16,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DeepseekV2Config,
@ -598,6 +626,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
@ -649,6 +678,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -684,6 +714,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -722,6 +753,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=GPTNeoXConfig,
@ -755,6 +787,31 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            return CausalLM.fallback(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == PHI_MOE:
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
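Putting the PhiMoE pieces together: the PHI_MOE enum entry added earlier and the dispatch branch above follow a registry pattern where each enum value is a metadata dict and the loader reuses the Llama flash path with a model-specific config class. A hedged, simplified sketch of that pattern (the real get_model wires many more parameters):

import enum

class ModelTypeSketch(enum.Enum):
    PHI = {"type": "phi", "name": "Phi"}
    PHI_MOE = {"type": "phimoe", "name": "PhiMoe"}

    @classmethod
    def from_config(cls, model_type: str) -> "ModelTypeSketch":
        for member in cls:
            if member.value["type"] == model_type:
                return member
        raise ValueError(f"Unsupported model_type: {model_type}")

def get_model_sketch(model_type: str) -> str:
    resolved = ModelTypeSketch.from_config(model_type)
    if resolved is ModelTypeSketch.PHI_MOE:
        # PhiMoE reuses the Llama flash path with its own config class.
        return "FlashCausalLM(model_class=FlashLlamaForCausalLM, config_class=PhiMoEConfig)"
    return "CausalLM.fallback(...)"

print(get_model_sketch("phimoe"))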
@ -794,6 +851,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -817,6 +875,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
@ -842,6 +901,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
@ -868,6 +928,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -892,6 +953,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
@ -922,6 +984,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                aliases={
                    "lm_head.weight": ["transformer.word_embeddings.weight"],
                    "transformer.word_embeddings.weight": ["lm_head.weight"],
@ -940,6 +1003,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                aliases={
                    "lm_head.weight": ["transformer.word_embeddings.weight"],
                    "transformer.word_embeddings.weight": ["lm_head.weight"],
@ -967,6 +1031,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -991,6 +1056,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -1015,6 +1081,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -1041,6 +1108,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@ -1085,7 +1153,7 @@ def get_model(
        )
    if model_type == IDEFICS:
        if FLASH_ATTENTION:
            return IDEFICSSharded(
            return IdeficsCausalLM(
                model_id,
                revision,
                quantize=quantize,
@ -1095,6 +1163,22 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
                model_id=model_id,
                model_class=MllamaForConditionalGeneration,
                batch_class=MllamaCausalLMBatch,
                revision=revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
    if model_type == IDEFICS2:
        if FLASH_ATTENTION:
            return VlmCausalLM(
@ -1104,6 +1188,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                # XXX: Extremely important to cap resolution in order to limit
@ -1121,6 +1206,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
@ -1139,6 +1225,7 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                kv_cache_dtype=kv_cache_dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
@ -1211,6 +1298,7 @@ def get_model_with_lora_adapters(
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
@ -1224,6 +1312,7 @@ def get_model_with_lora_adapters(
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )
@ -517,11 +517,10 @@ class CausalLM(Model):
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            dtype = default_dtype if dtype is None else dtype
        else:
        elif SYSTEM == "ipex":
            device = torch.device("cpu")
            # Float16 doesn't exist on target.
            dtype = torch.bfloat16 if dtype is None else dtype
@ -593,8 +592,14 @@ class CausalLM(Model):
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        device_count = 0
        if torch.cuda.is_available():
            device = torch.device("cuda")
            device_count = torch.cuda.device_count()
            dtype = torch.float16 if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            device_count = torch.xpu.device_count()
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
@ -614,20 +619,12 @@ class CausalLM(Model):
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            device_map=("auto" if device_count > 1 else None),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if (
            torch.cuda.is_available()
            and torch.cuda.device_count() == 1
            and quantize != "bitsandbytes"
        ):
            model = model.cuda()
        if device_count == 1 and quantize != "bitsandbytes":
            model = model.to(device)

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
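The refactor above replaces CUDA-only checks with a backend-neutral device_count, so the device_map and .to(device) decisions work the same on CUDA and XPU. A hedged standalone sketch of that selection logic (rank handling and the SYSTEM constant are simplified away):

from typing import Optional
import torch

def pick_device_and_dtype(dtype: Optional[torch.dtype] = None):
    device_count = 0
    if torch.cuda.is_available():
        device = torch.device("cuda")
        device_count = torch.cuda.device_count()
        dtype = torch.float16 if dtype is None else dtype
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # Intel GPUs expose the same availability/count API under torch.xpu.
        device = torch.device("xpu")
        device_count = torch.xpu.device_count()
        dtype = torch.float16 if dtype is None else dtype
    else:
        device = torch.device("cpu")
        # Float16 doesn't exist on some CPU targets; bfloat16 is the safe default.
        dtype = torch.bfloat16 if dtype is None else dtype
    return device, device_count, dtype

device, device_count, dtype = pick_device_and_dtype()
# device_map="auto" only makes sense with more than one accelerator:
device_map = "auto" if device_count > 1 else None
print(device, device_count, dtype, device_map)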
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    reshape_and_cache,
    Seqlen,
)
from text_generation_server.utils.import_utils import SYSTEM
@ -291,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module):

        self.rotary_emb(query, key, cos, sin)

        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
        kv_cache.store(key=key, value=value, slots=slots)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                kv_cache.key if PREFILL_IN_KV_CACHE else key,
                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                seqlen,
                block_tables,
                self.softmax_scale,
@ -308,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module):
        else:
            attn_output = paged_attention(
                query,
                kv_cache[0],
                kv_cache[1],
                kv_cache.key,
                kv_cache.value,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
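This hunk, like the Dbrx, DeepseekV2 and Gemma hunks below, migrates from indexing the KV cache as a (key, value) pair to an object with .key, .value and a .store() method. A hedged sketch of what such a wrapper could look like; the real TGI class likely also handles paged block layouts and fp8 storage:

from dataclasses import dataclass
import torch

@dataclass
class KVCacheSketch:
    key: torch.Tensor    # [total_slots, num_kv_heads, head_dim]
    value: torch.Tensor  # [total_slots, num_kv_heads, head_dim]

    def store(self, *, key, value, slots):
        # Scatter this step's K/V into their assigned cache slots, replacing
        # the old free-function reshape_and_cache(...) call.
        self.key[slots] = key.to(self.key.dtype)
        self.value[slots] = value.to(self.value.dtype)

cache = KVCacheSketch(key=torch.zeros(16, 2, 4), value=torch.zeros(16, 2, 4))
cache.store(key=torch.randn(3, 2, 4), value=torch.randn(3, 2, 4),
            slots=torch.tensor([0, 1, 5]))
print(cache.key[5])

Centralizing the write in one method also gives a single place to cast incoming K/V to a narrower kv_cache_dtype.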
@ -28,7 +28,6 @@ if SYSTEM != "ipex":
    from text_generation_server.layers.attention import (
        paged_attention,
        attention,
        reshape_and_cache,
        Seqlen,
        PREFILL_IN_KV_CACHE,
    )
@ -330,15 +329,15 @@ class DbrxAttention(torch.nn.Module):

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@ -347,8 +346,8 @@ class DbrxAttention(torch.nn.Module):
        else:
            attn_output = paged_attention(
                query,
                kv_cache[0],
                kv_cache[1],
                kv_cache.key,
                kv_cache.value,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
    reshape_and_cache,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
@ -321,15 +320,15 @@ class DeepseekV2Attention(torch.nn.Module):
                value, (0, self.head_pad_size - self.value_head_size), value=0
            )

        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
        kv_cache.store(key=key, value=value, slots=slots)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                kv_cache.key if PREFILL_IN_KV_CACHE else key,
                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                seqlen,
                block_tables,
                self.softmax_scale,
@ -338,8 +337,8 @@ class DeepseekV2Attention(torch.nn.Module):
        else:
            attn_output = paged_attention(
                query,
                kv_cache[0],
                kv_cache[1],
                kv_cache.key,
                kv_cache.value,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
@ -390,6 +389,7 @@ class DeepseekV2MLP(nn.Module):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
            and hidden_states.dtype == torch.float16
            and hidden_states.shape[0] == 1
            and not self.quantize
        ):
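The condition above gates a ROCm-specific fused kernel to exactly the case it supports: SiLU activation, float16, a single-token decode batch, and no quantization. A hedged sketch of the dispatch shape; the fused call is vendor code, so the sketch falls through to the eager SwiGLU either way:

import torch
import torch.nn.functional as F

SYSTEM = "rocm"  # assumption for illustration; detected at import time in TGI

def gated_silu_mlp(hidden_states, gate_up, down, hidden_act="silu", quantize=False):
    use_fast_path = (
        SYSTEM == "rocm"
        and hidden_act == "silu"
        and hidden_states.dtype == torch.float16
        and hidden_states.shape[0] == 1   # single-token decode step only
        and not quantize
    )
    if use_fast_path:
        # The real code calls a fused ROCm SwiGLU kernel here; this sketch
        # uses the eager equivalent below in both branches.
        pass
    h = hidden_states @ gate_up.T
    gate, up = h.chunk(2, dim=-1)
    return (F.silu(gate) * up) @ down.T

x = torch.randn(2, 8)  # float32 batch of 2: fast path not taken, eager path runs
print(gated_silu_mlp(x, torch.randn(12, 8), torch.randn(8, 6)).shape)  # torch.Size([2, 8])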
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    reshape_and_cache,
    Seqlen,
)
from text_generation_server.layers import (
@ -253,15 +252,15 @@ class FlashGemma2Attention(torch.nn.Module):

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@ -273,8 +272,8 @@ class FlashGemma2Attention(torch.nn.Module):
        else:
            attn_output = paged_attention(
                query,
                kv_cache[0],
                kv_cache[1],
                kv_cache.key,
                kv_cache.value,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    reshape_and_cache,
    Seqlen,
    PREFILL_IN_KV_CACHE,
)
@ -224,15 +223,15 @@ class FlashGemmaAttention(torch.nn.Module):

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@ -242,8 +241,8 @@ class FlashGemmaAttention(torch.nn.Module):
        else:
            attn_output = paged_attention(
                query,
                kv_cache[0],
                kv_cache[1],
                kv_cache.key,
                kv_cache.value,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
Some files were not shown because too many files have changed in this diff.