diff --git a/server/marlin/marlin_kernels/py.typed b/.devcontainer/Dockerfile.trtllm similarity index 100% rename from server/marlin/marlin_kernels/py.typed rename to .devcontainer/Dockerfile.trtllm diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000..e69de29bb diff --git a/.dockerignore b/.dockerignore index c69283ec5..1c641e7a5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,5 @@ aml target server/transformers server/flash-attention +cmake-build-debug/ +cmake-build-release/ diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml index e10b232c7..a768f263c 100644 --- a/.github/workflows/autodocs.yaml +++ b/.github/workflows/autodocs.yaml @@ -28,7 +28,7 @@ jobs: - name: Install router id: install-router - run: cargo install --path router/ + run: cargo install --path backends/v3/ - uses: actions/setup-node@v4 with: @@ -41,5 +41,5 @@ jobs: - name: Check that documentation is up-to-date run: | - npm install -g swagger-cli + npm install -g @redocly/cli python update_doc.py --check diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index cd9f19ba0..89d5bdf5c 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -27,8 +27,8 @@ jobs: concurrency: group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true - # TODO see with @Glegendre to get CPU runner here instead - runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] + runs-on: + group: aws-highmemory-32-plus-priv permissions: contents: write packages: write @@ -49,7 +49,7 @@ jobs: export dockerfile="Dockerfile" export label_extension="" export docker_devices="" - export runs_on="nvidia-gpu" + export runs_on="aws-g6-12xlarge-plus-priv" ;; rocm) export dockerfile="Dockerfile_amd" @@ -75,13 +75,18 @@ jobs: echo "LABEL=${label_extension}" >> $GITHUB_ENV echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV + echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v3 with: install: true - config-inline: | - [registry."docker.io"] - mirrors = ["registry.github-runners.huggingface.tech"] + buildkitd-config: /tmp/buildkitd.toml + - name: Login to internal Container Registry + uses: docker/login-action@v3 + with: + username: ${{ secrets.REGISTRY_USERNAME }} + password: ${{ secrets.REGISTRY_PASSWORD }} + registry: registry.internal.huggingface.tech - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v3 @@ -103,7 +108,7 @@ jobs: uses: docker/metadata-action@v5 with: images: | - registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference + registry.internal.huggingface.tech/api-inference/community/text-generation-inference tags: | type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} # If main, release or tag @@ -115,7 +120,7 @@ jobs: flavor: | latest=auto images: | - registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference + registry.internal.huggingface.tech/api-inference/community/text-generation-inference ghcr.io/huggingface/text-generation-inference db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference tags: | @@ -141,7 +146,7 @@ jobs: - name: Final id: final run: | - echo 
"docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" + echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT" echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT" echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT" @@ -150,7 +155,8 @@ jobs: group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true needs: build-and-push - runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] + runs-on: + group: ${{ needs.build-and-push.outputs.runs_on }} if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index d62297e4a..5ca2854a9 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -10,6 +10,7 @@ on: paths: - ".github/workflows/build.yaml" - "integration-tests/**" + - "backends/**" - "server/**" - "proto/**" - "router/**" diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml index 0399e6d19..7336cb73d 100644 --- a/.github/workflows/load_test.yaml +++ b/.github/workflows/load_test.yaml @@ -23,7 +23,8 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true - runs-on: [ self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci ] + runs-on: + group: aws-g5-12xlarge env: DOCKER_VOLUME: /cache steps: diff --git a/.gitignore b/.gitignore index e9ad1808c..0de8b8481 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,10 @@ target router/tokenizer.json *__pycache__* +backends/v3/src/client/pb +backends/client/src/v2/pb +backends/client/src/v3/pb + # ROCm auto-generated files *.hip server/exllamav2_kernels/exllamav2_kernels/hip/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45bc07a54..6f5e685ea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,6 +13,11 @@ repos: - repo: https://github.com/doublify/pre-commit-rust rev: v1.0 hooks: - - id: fmt - id: cargo-check + - id: fmt - id: clippy +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml new file mode 100644 index 000000000..382c9ab64 --- /dev/null +++ b/.redocly.lint-ignore.yaml @@ -0,0 +1,79 @@ +# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API. +# See https://redoc.ly/docs/cli/ for more information. 
+docs/openapi.json: + no-empty-servers: + - '#/openapi' + spec: + - >- + #/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum + - >- + #/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/grammar/nullable' + - >- + #/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum + - '#/components/schemas/GenerateResponse/properties/details/nullable' + - '#/components/schemas/StreamResponse/properties/details/nullable' + - '#/components/schemas/ChatRequest/properties/response_format/nullable' + - '#/components/schemas/ChatRequest/properties/tool_choice/nullable' + - '#/components/schemas/ToolChoice/nullable' + - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable' + - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable' + no-invalid-media-type-examples: + - '#/paths/~1/post/responses/422/content/application~1json/example' + - '#/paths/~1/post/responses/424/content/application~1json/example' + - '#/paths/~1/post/responses/429/content/application~1json/example' + - '#/paths/~1/post/responses/500/content/application~1json/example' + - '#/paths/~1generate/post/responses/422/content/application~1json/example' + - '#/paths/~1generate/post/responses/424/content/application~1json/example' + - '#/paths/~1generate/post/responses/429/content/application~1json/example' + - '#/paths/~1generate/post/responses/500/content/application~1json/example' + - >- + #/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example + - '#/paths/~1tokenize/post/responses/404/content/application~1json/example' + - >- + #/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/500/content/application~1json/example + operation-4xx-response: + - '#/paths/~1health/get/responses' + - '#/paths/~1info/get/responses' + - '#/paths/~1metrics/get/responses' + no-unused-components: + - '#/components/schemas/Completion' + security-defined: + - '#/paths/~1/post' + - '#/paths/~1generate/post' + - '#/paths/~1generate_stream/post' + - '#/paths/~1health/get' + - 
'#/paths/~1info/get' + - '#/paths/~1metrics/get' + - '#/paths/~1tokenize/post' + - '#/paths/~1v1~1chat~1completions/post' + - '#/paths/~1v1~1completions/post' diff --git a/Cargo.lock b/Cargo.lock index ffc98baa9..92367d1ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,7 +28,7 @@ dependencies = [ "once_cell", "serde", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -48,9 +48,9 @@ checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" [[package]] name = "anstream" -version = "0.6.14" +version = "0.6.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" dependencies = [ "anstyle", "anstyle-parse", @@ -63,33 +63,33 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" [[package]] name = "anstyle-parse" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" dependencies = [ "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" dependencies = [ "anstyle", "windows-sys 0.52.0", @@ -121,7 +121,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -160,18 +160,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -234,9 +234,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.7.3" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf7d844e282b4b56750b2d4e893b2205581ded8709fddd2b6aa5418c150ca877" +checksum = "4ae74d9bd0a7530e8afd1770739ad34b36838829d6ad61818f9230f683f5ad77" dependencies = [ "aws-lc-sys", "mirai-annotations", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.18.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a2c29203f6bf296d01141cc8bb9dbd5ecd4c27843f2ee0767bcd5985a927da" +checksum = 
"2e89b6941c2d1a7045538884d6e760ccfffdf8e1ffc2613d8efa74305e1f3752" dependencies = [ "bindgen", "cc", @@ -272,7 +272,7 @@ dependencies = [ "futures-util", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "itoa", "matchit", "memchr", @@ -302,9 +302,9 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "itoa", "matchit", @@ -352,7 +352,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -433,7 +433,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.68", + "syn 2.0.72", "which", ] @@ -472,9 +472,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415f8399438eb5e4b2f73ed3152a3448b98149dda642a957ee704e1daa5cf1d8" +checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499" [[package]] name = "block-buffer" @@ -487,9 +487,9 @@ dependencies = [ [[package]] name = "built" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6a6c0b39c38fd754ac338b00a88066436389c0f029da5d37d1e01091d9b7c17" +checksum = "236e6289eda5a812bc6b53c3b024039382a2895fbbeef2d748b2931546d392c4" [[package]] name = "bumpalo" @@ -523,9 +523,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" [[package]] name = "camino" @@ -567,13 +567,12 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" [[package]] name = "cc" -version = "1.0.101" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac367972e516d45567c7eafc73d24e1c193dcf200a8d94e9db7b3d38b349572d" +checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" dependencies = [ "jobserver", "libc", - "once_cell", ] [[package]] @@ -620,9 +619,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.7" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" +checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" dependencies = [ "clap_builder", "clap_derive", @@ -630,9 +629,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.7" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" +checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" dependencies = [ "anstream", "anstyle", @@ -642,21 +641,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.5" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" +checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] 
[[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "cmake" @@ -667,6 +666,16 @@ dependencies = [ "cc", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -675,9 +684,9 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "colorchoice" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "console" @@ -769,7 +778,7 @@ dependencies = [ "bitflags 2.6.0", "crossterm_winapi", "libc", - "mio", + "mio 0.8.11", "parking_lot", "signal-hook", "signal-hook-mio", @@ -801,6 +810,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "ctrlc" version = "3.4.4" @@ -812,10 +842,54 @@ dependencies = [ ] [[package]] -name = "darling" -version = "0.20.9" +name = "cxx" +version = "1.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.72", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ "darling_core", "darling_macro", @@ -823,27 +897,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "darling_macro" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -873,7 +947,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -883,7 +957,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1158,7 +1232,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1412,9 +1486,9 @@ dependencies = [ [[package]] name = "http-body" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", "http 1.1.0", @@ -1429,7 +1503,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1447,9 +1521,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.29" +version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ "bytes", "futures-channel", @@ -1471,16 +1545,16 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", "futures-util", "h2 0.4.5", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1498,10 +1572,10 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "log", - "rustls 0.23.10", + "rustls 0.23.12", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1515,7 +1589,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper 0.14.29", + "hyper 0.14.30", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1528,7 +1602,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.29", + "hyper 0.14.30", "native-tls", "tokio", "tokio-native-tls", @@ -1536,16 +1610,16 @@ dependencies = [ [[package]] name = 
"hyper-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" +checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", - "http-body 1.0.0", - "hyper 1.3.1", + "http-body 1.0.1", + "hyper 1.4.1", "pin-project-lite", "socket2", "tokio", @@ -1572,12 +1646,12 @@ dependencies = [ [[package]] name = "image" -version = "0.25.1" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +checksum = "99314c8a2152b8ddb211f924cdae532d8c5e4c8bb54728e12fff1b0cd5963a10" dependencies = [ "bytemuck", - "byteorder", + "byteorder-lite", "color_quant", "exr", "gif", @@ -1595,12 +1669,12 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d730b085583c4d789dfd07fdcf185be59501666a90c97c40162b37e4fdad272d" +checksum = "f79afb8cbee2ef20f59ccd477a218c12a93943d075b492015ecb1bb81f8ee904" dependencies = [ "byteorder-lite", - "thiserror", + "quick-error", ] [[package]] @@ -1679,7 +1753,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1690,9 +1764,9 @@ checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is_terminal_polyfill" -version = "1.70.0" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "iso8601" @@ -1738,9 +1812,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -1827,12 +1901,12 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1851,6 +1925,15 @@ dependencies = [ "libc", ] +[[package]] +name = "link-cplusplus" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1869,9 +1952,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "loop9" @@ -1926,7 +2009,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" dependencies = [ "cfg-if", - "rayon", ] [[package]] @@ -1947,13 +2029,13 @@ dependencies = [ [[package]] name = "metrics-exporter-prometheus" -version = "0.15.1" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0af7a0d7ced10c0151f870e5e3f3f8bc9ffc5992d32873566ca1f9169ae776" +checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ "base64 0.22.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-rustls", "hyper-util", "indexmap 2.2.6", @@ -1989,9 +2071,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mime_guess" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" dependencies = [ "mime", "unicase", @@ -1999,18 +2081,18 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4" +checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" dependencies = [ "serde", ] [[package]] name = "minijinja-contrib" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07" +checksum = "6853ef2340c668281c5ea86b04da2ebb2fc9e98a7185a887591de4cac945d5b5" dependencies = [ "minijinja", "serde", @@ -2044,6 +2126,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mio" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + [[package]] name = "mirai-annotations" version = "1.12.0" @@ -2068,7 +2162,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2134,7 +2228,7 @@ dependencies = [ "bytes", "futures", "hostname", - "hyper 0.14.29", + "hyper 0.14.30", "muxado", "once_cell", "parking_lot", @@ -2219,9 +2313,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2256,7 +2350,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2327,9 +2421,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.0" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" dependencies = [ "memchr", ] @@ -2364,9 +2458,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.64" 
+version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -2385,7 +2479,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2396,9 +2490,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", @@ -2432,6 +2526,20 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "opentelemetry" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + [[package]] name = "opentelemetry-otlp" version = "0.13.0" @@ -2525,7 +2633,27 @@ dependencies = [ "glob", "once_cell", "opentelemetry 0.21.0", - "ordered-float 4.2.0", + "ordered-float 4.2.2", + "percent-encoding", + "rand", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "lazy_static", + "once_cell", + "opentelemetry 0.23.0", + "ordered-float 4.2.2", "percent-encoding", "rand", "thiserror", @@ -2548,9 +2676,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "4.2.0" +version = "4.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6" dependencies = [ "num-traits", ] @@ -2592,7 +2720,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2634,7 +2762,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2670,9 +2798,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" [[package]] name = "powerfmt" @@ -2682,9 +2810,12 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +dependencies = [ + "zerocopy 0.6.6", +] [[package]] name = "prettyplease" @@ -2693,7 +2824,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2745,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2785,7 +2916,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.68", + "syn 2.0.72", "tempfile", ] @@ -2812,7 +2943,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2947,24 +3078,23 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.7" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67376f469e7e7840d0040bbf4b9b3334005bb167f814621326e4c7ab8cd6e944" +checksum = "5797d09f9bd33604689e87e8380df4951d4912f01b63f71205e2abd4ae25e6b6" dependencies = [ "avif-serialize", "imgref", "loop9", "quick-error", "rav1e", - "rayon", "rgb", ] [[package]] name = "raw-cpuid" -version = "11.0.2" +version = "11.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" +checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d" dependencies = [ "bitflags 2.6.0", ] @@ -3002,9 +3132,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" dependencies = [ "bitflags 2.6.0", ] @@ -3078,7 +3208,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-tls", "ipnet", "js-sys", @@ -3106,9 +3236,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.37" +version = "0.8.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" +checksum = "ade4539f42266ded9e755c605bdddf546242b2c961b03b06a7375260788a0523" dependencies = [ "bytemuck", ] @@ -3145,9 +3275,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19549741604902eb99a7ed0ee177a0663ee1eda51a29f71401f166e47e77806a" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -3156,22 +3286,22 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb9f96e283ec64401f30d3df8ee2aaeb2561f34c824381efa24a35f79bf40ee4" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.68", + "syn 2.0.72", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c74a686185620830701348de757fd36bef4aa9680fd23c49fc539ddcc1af32" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" dependencies = [ "sha2", "walkdir", @@ -3239,9 +3369,9 @@ dependencies = 
[ [[package]] name = "rustls" -version = "0.23.10" +version = "0.23.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" dependencies = [ "aws-lc-rs", "log", @@ -3254,9 +3384,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", "rustls-pemfile 2.1.2", @@ -3292,9 +3422,9 @@ checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" dependencies = [ "aws-lc-rs", "ring 0.17.8", @@ -3338,6 +3468,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" + [[package]] name = "sct" version = "0.7.1" @@ -3350,9 +3486,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", "core-foundation", @@ -3363,9 +3499,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" dependencies = [ "core-foundation-sys", "libc", @@ -3382,31 +3518,32 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "serde_json" -version = "1.0.118" +version = "1.0.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4" +checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3423,9 +3560,9 @@ dependencies = [ [[package]] name = 
"serde_spanned" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" dependencies = [ "serde", ] @@ -3480,12 +3617,12 @@ dependencies = [ [[package]] name = "signal-hook-mio" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", - "mio", + "mio 0.8.11", "signal-hook", ] @@ -3605,7 +3742,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -3627,9 +3764,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.68" +version = "2.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" dependencies = [ "proc-macro2", "quote", @@ -3650,15 +3787,16 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "sysinfo" -version = "0.30.12" +version = "0.30.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", + "rayon", "windows", ] @@ -3722,9 +3860,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" [[package]] name = "tempfile" @@ -3738,9 +3876,40 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "text-generation-backends-trtllm" +version = "2.2.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "clap", + "cmake", + "cxx", + "cxx-build", + "log", + "pkg-config", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tracing", + "tracing-opentelemetry 0.24.0", + "tracing-subscriber", +] + [[package]] name = "text-generation-benchmark" -version = "2.1.2-dev0" +version = "2.2.1-dev0" dependencies = [ "average", "clap", @@ -3761,7 +3930,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "2.1.2-dev0" +version = "2.2.1-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -3779,7 +3948,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "2.1.2-dev0" +version = "2.2.1-dev0" dependencies = [ "clap", "ctrlc", @@ -3798,13 +3967,15 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "2.1.2-dev0" +version = "2.2.1-dev0" dependencies = [ "async-stream", + "async-trait", "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", "clap", + "csv", "futures", "futures-util", "hf-hub", @@ -3826,7 +3997,7 @@ dependencies = [ 
"reqwest", "serde", "serde_json", - "text-generation-client", + "sysinfo", "thiserror", "tokenizers", "tokio", @@ -3835,29 +4006,79 @@ dependencies = [ "tracing", "tracing-opentelemetry 0.21.0", "tracing-subscriber", + "ureq", "utoipa", "utoipa-swagger-ui", + "uuid", "vergen", ] +[[package]] +name = "text-generation-router-v3" +version = "2.2.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.7.5", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -3916,9 +4137,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -3964,21 +4185,20 @@ dependencies = [ [[package]] name = "tokio" -version = "1.38.0" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", "libc", - "mio", - "num_cpus", + "mio 1.0.1", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3993,13 +4213,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4029,7 +4249,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.10", + "rustls 0.23.12", "rustls-pki-types", "tokio", ] @@ -4061,9 +4281,9 @@ dependencies = [ [[package]] name = "toml" -version = 
"0.8.14" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335" +checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" dependencies = [ "serde", "serde_spanned", @@ -4073,18 +4293,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.14" +version = "0.22.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" +checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" dependencies = [ "indexmap 2.2.6", "serde", @@ -4108,7 +4328,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", @@ -4135,7 +4355,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", @@ -4158,7 +4378,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4190,7 +4410,7 @@ dependencies = [ "bitflags 2.6.0", "bytes", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "pin-project-lite", "tower-layer", @@ -4229,7 +4449,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4295,7 +4515,25 @@ dependencies = [ "tracing-core", "tracing-log 0.2.0", "tracing-subscriber", - "web-time", + "web-time 0.2.4", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.23.0", + "opentelemetry_sdk 0.23.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time 1.1.0", ] [[package]] @@ -4487,7 +4725,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4508,9 +4746,25 @@ dependencies = [ [[package]] name = "uuid" -version = "1.9.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", + "rand", + "uuid-macro-internal", +] + +[[package]] +name = "uuid-macro-internal" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] [[package]] name = "v_frame" @@ -4537,9 +4791,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.3.1" +version = "8.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27d6bdd219887a9eadd19e1c34f32e47fa332301184935c6d9bca26f3cca525" +checksum 
= "2990d9ea5967266ea0ccf413a4aa5c42a93dbcfda9cb49a97de6931726b12566" dependencies = [ "anyhow", "cargo_metadata", @@ -4559,9 +4813,9 @@ checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -4609,7 +4863,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", "wasm-bindgen-shared", ] @@ -4643,7 +4897,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4674,6 +4928,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" @@ -4749,7 +5013,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4758,7 +5022,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4785,7 +5049,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4820,18 +5084,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4848,9 +5112,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4866,9 +5130,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4884,15 +5148,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4908,9 +5172,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4926,9 +5190,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4944,9 +5208,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -4962,15 +5226,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.13" +version = "0.6.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" dependencies = [ "memchr", ] @@ -4987,22 +5251,43 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" dependencies = [ - "zerocopy-derive", + "byteorder", + "zerocopy-derive 0.6.6", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive 0.7.35", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", ] [[package]] @@ -5022,7 +5307,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -5054,9 +5339,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448" +checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768" dependencies = [ "zune-core", ] diff --git a/Cargo.toml b/Cargo.toml index 3866a8b31..8bf75b902 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,24 @@ [workspace] members = [ - "benchmark", - "router", - "router/client", - "router/grpc-metadata", - "launcher" + "benchmark", + "backends/v3", + "backends/grpc-metadata", + "backends/trtllm", + "backends/client", + "launcher" +] +default-members = [ + "benchmark", + "backends/v3", + "backends/grpc-metadata", + # "backends/trtllm", + "backends/client", + "launcher" ] resolver = "2" [workspace.package] -version = "2.1.2-dev0" +version = "2.2.1-dev0" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" @@ -18,6 +27,8 @@ homepage = "https://github.com/huggingface/text-generation-inference" base64 = "0.22.0" tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } +metrics = { version = "0.23.0" } +metrics-exporter-prometheus = { version = "0.15.1", features = [] } [profile.release] incremental = true diff --git a/Dockerfile b/Dockerfile index 3f2e8ef01..0d57e38da 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo build --profile release-opt @@ -41,7 +43,7 @@ RUN cargo build --profile release-opt FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. 
Context: https://github.com/huggingface/text-generation-inference/pull/2099 -ARG PYTORCH_VERSION=2.3.0 +ARG PYTORCH_VERSION=2.4.0 ARG PYTHON_VERSION=3.10 # Keep in sync with `server/pyproject.toml @@ -140,13 +142,6 @@ COPY server/Makefile-eetq Makefile # Build specific version of transformers RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq -# Build marlin kernels -FROM kernel-builder AS marlin-kernels-builder -WORKDIR /usr/src -COPY server/marlin/ . -# Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build - # Build Lorax Punica kernels FROM kernel-builder AS lorax-punica-builder WORKDIR /usr/src @@ -161,6 +156,15 @@ COPY server/custom_kernels/ . # Build specific version of transformers RUN python setup.py build +# Build FBGEMM CUDA kernels +FROM kernel-builder AS fbgemm-builder + +WORKDIR /usr/src + +COPY server/Makefile-fbgemm Makefile + +RUN make build-fbgemm + # Build vllm CUDA kernels FROM kernel-builder AS vllm-builder @@ -222,13 +226,10 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from eetq kernels builder COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Copy build artifacts from marlin kernels builder -COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - -# Copy builds artifacts from vllm builder +# Copy build artifacts from fbgemm builder +COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages @@ -243,7 +244,7 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements_cuda.txt && \ - pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \ + pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \ pip install nvidia-nccl-cu12==2.22.3 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 diff --git a/Dockerfile.trtllm b/Dockerfile.trtllm new file mode 100644 index 000000000..4543ae804 --- /dev/null +++ b/Dockerfile.trtllm @@ -0,0 +1,23 @@ +# All the tooling for CUDA +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +WORKDIR /usr/src/tgi/backends/trtllm +RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget + +COPY . /usr/src/tgi +RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh +RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include . 
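+# Only the C++ backend implementation target is compiled here; the resulting build/ directory is copied into the Rust-based image below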
+RUN cmake --build build --parallel -t tgi_trtllm_backend_impl + +# All the tooling for Rust +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +WORKDIR /usr/src + +# Include CUDA related libraries and tools to the Rust based image +COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda +COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build +ENV PATH=/usr/local/cuda/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH + +RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3 diff --git a/Dockerfile_amd b/Dockerfile_amd index 0aebeee57..51231638c 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo build --profile release-opt diff --git a/Dockerfile_intel b/Dockerfile_intel index 6a803a32b..d20f0a012 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo build --profile release-opt diff --git a/Makefile b/Makefile index a1399b6d7..3068a06f4 100644 --- a/Makefile +++ b/Makefile @@ -5,13 +5,13 @@ install-server-cpu: cd server && make install-server install-router: - cd router && cargo install --path . + cargo install --path backends/v3/ install-launcher: - cd launcher && cargo install --path . + cargo install --path launcher/ install-benchmark: - cd benchmark && cargo install --path . + cargo install --path benchmark/ install: install-server install-router install-launcher diff --git a/README.md b/README.md index 4287c1195..a88e0437c 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ model=HuggingFaceH4/zephyr-7b-beta volume=$PWD/data docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:2.1.1 --model-id $model + ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model ``` And then you can make requests like @@ -94,7 +94,7 @@ curl 127.0.0.1:8080/generate_stream \ **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. -**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). 
To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.1-rocm --model-id $model` instead of the command above. +**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above. To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): ``` diff --git a/router/client/Cargo.toml b/backends/client/Cargo.toml similarity index 100% rename from router/client/Cargo.toml rename to backends/client/Cargo.toml diff --git a/router/client/build.rs b/backends/client/build.rs similarity index 100% rename from router/client/build.rs rename to backends/client/build.rs diff --git a/router/client/src/lib.rs b/backends/client/src/lib.rs similarity index 100% rename from router/client/src/lib.rs rename to backends/client/src/lib.rs diff --git a/router/client/src/v2/client.rs b/backends/client/src/v2/client.rs similarity index 100% rename from router/client/src/v2/client.rs rename to backends/client/src/v2/client.rs diff --git a/router/client/src/v2/mod.rs b/backends/client/src/v2/mod.rs similarity index 100% rename from router/client/src/v2/mod.rs rename to backends/client/src/v2/mod.rs diff --git a/router/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs similarity index 100% rename from router/client/src/v2/sharded_client.rs rename to backends/client/src/v2/sharded_client.rs diff --git a/router/client/src/v3/client.rs b/backends/client/src/v3/client.rs similarity index 100% rename from router/client/src/v3/client.rs rename to backends/client/src/v3/client.rs diff --git a/router/client/src/v3/mod.rs b/backends/client/src/v3/mod.rs similarity index 100% rename from router/client/src/v3/mod.rs rename to backends/client/src/v3/mod.rs diff --git a/router/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs similarity index 100% rename from router/client/src/v3/sharded_client.rs rename to backends/client/src/v3/sharded_client.rs diff --git a/router/grpc-metadata/Cargo.toml b/backends/grpc-metadata/Cargo.toml similarity index 100% rename from router/grpc-metadata/Cargo.toml rename to backends/grpc-metadata/Cargo.toml diff --git a/router/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs similarity index 100% rename from router/grpc-metadata/src/lib.rs rename to backends/grpc-metadata/src/lib.rs diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt new file mode 100644 index 000000000..425b2d7b9 --- /dev/null +++ b/backends/trtllm/CMakeLists.txt @@ -0,0 +1,63 @@ +cmake_minimum_required(VERSION 3.20) + +project(tgi-trtllm-backend VERSION 1.0.0) +set(CMAKE_CXX_STANDARD 20) + +include(FetchContent) +include(ExternalProject) + +option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) +option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) +set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") +set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING 
"Path where TensorRT libraries and headers are located") +set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") +set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located") + +# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features +find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) + +#### External dependencies #### +include(cmake/fmt.cmake) +include(cmake/json.cmake) +include(cmake/spdlog.cmake) +include(cmake/trtllm.cmake) + +# Let's build TRTLLM as part of CMake +add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..") + +# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so +set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE) + +# TGI TRTLLM Backend definition +add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h) +include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +target_include_directories(tgi_trtllm_backend_impl PRIVATE + $ + $ +) +target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") +target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml) +target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt) + +# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back +install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) +install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) + +#### Unit Tests #### +if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) + message(STATUS "Building tests") + FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2 + GIT_TAG v3.6.0 + ) + FetchContent_MakeAvailable(Catch2) + + # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp) + # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml) + + list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) + include(CTest) + include(Catch) + # catch_discover_tests(tgi_trtllm_backend_tests) +endif () diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml new file mode 100644 index 000000000..7079d3d11 --- /dev/null +++ b/backends/trtllm/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "text-generation-backends-trtllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-trait = "0.1" +async-stream = "0.3" +cxx = "1.0" +text-generation-router = { path = "../../router" } +tokenizers = { version = "0.19", features = ["hf-hub"] } +tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.15" +clap = { version = "4.5", features = ["derive"] } +thiserror = "1.0.62" +tracing = "0.1" +tracing-opentelemetry = "0.24" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +log = { version = "0.4", features = [] } + +[build-dependencies] +cmake = "0.1" +cxx-build = { version = "1.0", features = ["parallel"] } 
+pkg-config = "0.3" diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile new file mode 100644 index 000000000..60ad03f72 --- /dev/null +++ b/backends/trtllm/Dockerfile @@ -0,0 +1,100 @@ +ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" +ARG OMPI_VERSION="4.1.6" + +# Build dependencies resolver stage +FROM lukemathwalker/cargo-chef:latest AS chef +WORKDIR /usr/src/text-generation-inference + +FROM chef AS planner +COPY . . +RUN cargo chef prepare --recipe-path recipe.json + +# CUDA dependent dependencies resolver stage +FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt update && apt install -y \ + build-essential \ + cmake \ + curl \ + gcc \ + g++ \ + git \ + git-lfs \ + libssl-dev \ + ninja-build \ + pkg-config \ + python3 \ + python3-setuptools \ + tar \ + wget + +ENV TGI_INSTALL_PREFIX=/usr/local/tgi +ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt + +# Install OpenMPI +FROM cuda-builder AS mpi-builder +ARG OMPI_VERSION + +ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2" +RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \ + mkdir /usr/src/mpi && \ + tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ + cd /usr/src/mpi && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \ + make -j all && \ + make install && \ + rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" + +# Install TensorRT +FROM cuda-builder AS trt-builder +COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh +RUN chmod +x /opt/install_tensorrt.sh && \ + /opt/install_tensorrt.sh + +# Build Backend +FROM cuda-builder AS tgi-builder +WORKDIR /usr/src/text-generation-inference + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \ + chmod -R a+w /root/.rustup && \ + chmod -R a+w /root/.cargo + +ENV PATH="/root/.cargo/bin:$PATH" +RUN cargo install cargo-chef + +# Cache dependencies +COPY --from=planner /usr/src/text-generation-inference/recipe.json . +RUN cargo chef cook --release --recipe-path recipe.json + +# Build actual TGI +ARG CUDA_ARCH_LIST +ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH" +ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH" +ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH" + +COPY . . +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm + +FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime +WORKDIR /usr/local/tgi/bin + +ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH" + +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi +COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher + +FROM runtime + +LABEL co.huggingface.vendor="Hugging Face Inc." 
+LABEL org.opencontainers.image.authors="hardware@hf.co"
+
+ENTRYPOINT ["./text-generation-launcher"]
+CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
diff --git a/backends/trtllm/README.md b/backends/trtllm/README.md
new file mode 100644
index 000000000..94064504d
--- /dev/null
+++ b/backends/trtllm/README.md
@@ -0,0 +1,46 @@
+# Text Generation Inference - TensorRT-LLM Backend Implementation
+
+## Description
+
+This folder provides the sources of the TensorRT-LLM backend implementation, powered by the new TensorRT-LLM Executor API.
+
+## Simplified Request Sequence
+
+```mermaid
+sequenceDiagram
+    actor User
+    participant TextGenerationInference.HttpServer
+    participant TextGenerationInference.TensorRtLlmBackend
+    participant TextGenerationInference.TensorRtLlmWorkerThread
+    participant TensorRtLlm.Executor
+    participant Nvidia.Gpu
+    User ->> TextGenerationInference.HttpServer: POST /generate
+    TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
+    TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
+    TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
+    activate Nvidia.Gpu
+    TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the pool for execution
+    TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Respond with a unique request identifier
+    rect rgb(10, 92, 54)
+        loop every 100us
+            rect rgb(15, 81, 50)
+                alt Acquire lock to query executor
+                    TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll the number of new token(s) generated for the request
+                else There are new generated tokens
+                    TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
+                    TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
+                    rect rgb(11, 110, 79)
+                        alt Generated token is final
+                            TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
+                            TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
+                        else Generated token is not final
+                            TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream tokens back to the user as they get decoded
+                        end
+                    end
+                end
+            end
+            deactivate Nvidia.Gpu
+        end
+    end
+
+```
diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs
new file mode 100644
index 000000000..086382624
--- /dev/null
+++ b/backends/trtllm/build.rs
@@ -0,0 +1,150 @@
+use cxx_build::CFG;
+use pkg_config;
+use std::env;
+use std::env::consts::ARCH;
+use std::path::{absolute, PathBuf};
+
+const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
+const CUDA_REQUIRED_VERSION: &str = "12.5";
+const MPI_REQUIRED_VERSION: &str = "4.1";
+const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
+const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
+const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
+
+// Dependencies
+const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
+const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
+const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
+    ("dylib",
"tensorrt_llm"), + ("static", "tensorrt_llm_executor_static"), + ("dylib", "tensorrt_llm_nvrtc_wrapper"), + ("dylib", "nvinfer_plugin_tensorrt_llm"), + ("dylib", "decoder_attention"), +]; + +macro_rules! probe { + ($name: expr, $version: expr) => { + if let Err(_) = pkg_config::probe_library($name) { + pkg_config::probe_library(&format!("{}-{}", $name, $version)) + .expect(&format!("Failed to locate {}", $name)); + } + }; +} + +fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { + // Build the backend implementation through CMake + let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); + let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt"); + let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default + + let mut install_path = PathBuf::from(install_path); + if !install_path.is_absolute() { + install_path = absolute(out_dir).expect("cannot happen").join(install_path); + } + + let _ = cmake::Config::new(".") + .uses_cxx11() + .generator("Ninja") + .profile(match is_debug { + true => "Debug", + false => "Release", + }) + .env("OPT_LEVEL", opt_level) + .define("CMAKE_INSTALL_PREFIX", &install_path) + .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") + .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) + .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path) + .build(); + + // Additional transitive CMake dependencies + let deps_folder = out_dir.join("build").join("_deps"); + for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES { + let dep_name = match is_debug { + true => format!("{}d", dependency), + false => String::from(dependency), + }; + let dep_path = deps_folder.join(format!("{}-build", dependency)); + println!("cargo:rustc-link-search={}", dep_path.display()); + println!("cargo:rustc-link-lib=static={}", dep_name); + } + + // Emit linkage information from the artifacts we just built + let install_lib_path = install_path.join("lib"); + + println!( + r"cargo:warning=Adding link search path: {}", + install_lib_path.display() + ); + println!(r"cargo:rustc-link-search={}", install_lib_path.display()); + + (PathBuf::from(install_path), deps_folder) +} + +fn build_ffi_layer(deps_folder: &PathBuf) { + CFG.include_prefix = "backends/trtllm"; + cxx_build::bridge("src/lib.rs") + .static_flag(true) + .include(deps_folder.join("fmt-src").join("include")) + .include(deps_folder.join("spdlog-src").join("include")) + .include(deps_folder.join("json-src").join("include")) + .include(deps_folder.join("trtllm-src").join("cpp").join("include")) + .include("/usr/local/cuda/include") + .include("/usr/local/tensorrt/include") + .file("src/ffi.cpp") + .std("c++20") + .compile("tgi_trtllm_backend"); + + println!("cargo:rerun-if-changed=CMakeLists.txt"); + println!("cargo:rerun-if-changed=include/backend.h"); + println!("cargo:rerun-if-changed=lib/backend.cpp"); + println!("cargo:rerun-if-changed=include/ffi.h"); + println!("cargo:rerun-if-changed=src/ffi.cpp"); +} + +fn main() { + // Misc variables + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let build_profile = env::var("PROFILE").unwrap(); + let (is_debug, opt_level) = match build_profile.as_ref() { + "debug" => (true, "0"), + _ => (false, "3"), + }; + + // Build the backend + let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir); + + // Build the FFI layer calling the backend above + build_ffi_layer(&deps_folder); + + // Emit linkage search path + probe!("ompi", MPI_REQUIRED_VERSION); + + // Probe CUDA & 
co. with pkg-config + CUDA_TRANSITIVE_DEPS.iter().for_each(|name| { + probe!(name, CUDA_REQUIRED_VERSION); + }); + + // NCCL is slightly trickier because it might not have a pkgconfig installed + let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH); + let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default); + println!(r"cargo:rustc-link-search=native={}", nccl_library_path); + println!("cargo:rustc-link-lib=dylib=nccl"); + + // TensorRT + let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib"); + println!(r"cargo:rustc-link-search=native={}", tensort_library_path); + println!("cargo:rustc-link-lib=dylib=nvinfer"); + + // TensorRT-LLM + TENSORRT_LLM_TRANSITIVE_DEPS + .iter() + .for_each(|(link_type, name)| { + println!("cargo:rustc-link-lib={}={}", link_type, name); + }); + + // Backend + BACKEND_DEPS.iter().for_each(|name| { + println!("cargo:rustc-link-lib=static={}", name); + }); +} diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake new file mode 100644 index 000000000..f94a9c566 --- /dev/null +++ b/backends/trtllm/cmake/fmt.cmake @@ -0,0 +1,6 @@ +FetchContent_Declare( + fmt + GIT_REPOSITORY https://github.com/fmtlib/fmt + GIT_TAG 11.0.1 +) +FetchContent_MakeAvailable(fmt) diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake new file mode 100644 index 000000000..29e5753b3 --- /dev/null +++ b/backends/trtllm/cmake/json.cmake @@ -0,0 +1,5 @@ +fetchcontent_declare( + json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz +) +fetchcontent_makeavailable(json) diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake new file mode 100644 index 000000000..c4ee5c97a --- /dev/null +++ b/backends/trtllm/cmake/spdlog.cmake @@ -0,0 +1,17 @@ +set(SPDLOG_USE_FMT ON) +set(SPDLOG_BUILD_SHARED OFF) +set(SPDLOG_FMT_EXTERNAL ON) + +# Define the level at which SPDLOG_ compilation level is defined +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) +else () + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO) +endif () + +fetchcontent_declare( + spdlog + GIT_REPOSITORY https://github.com/gabime/spdlog.git + GIT_TAG v1.14.1 +) +fetchcontent_makeavailable(spdlog) diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake new file mode 100644 index 000000000..e59ad4cf3 --- /dev/null +++ b/backends/trtllm/cmake/trtllm.cmake @@ -0,0 +1,42 @@ +set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR}) + +set(USE_CXX11_ABI ON) +set(BUILD_PYT OFF) +set(BUILD_PYBIND OFF) +set(BUILD_MICRO_BENCHMARKS OFF) +set(BUILD_BENCHMARKS OFF) +set(BUILD_TESTS OFF) +set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST}) + +message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}") + +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + set(FAST_BUILD ON) + set(NVTX_DISABLE OFF) +else () + set(FAST_BUILD OFF) + set(FAST_MATH ON) + set(NVTX_DISABLE ON) +endif () + +fetchcontent_declare( + trtllm + GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git + GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1 + GIT_SHALLOW FALSE +) +fetchcontent_makeavailable(trtllm) + +message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}") +execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") +execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") + 
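+# git lfs pull above fetches the prebuilt NVRTC wrapper and executor static libraries whose paths are resolved below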
+# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here +set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name") +set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}" + CACHE INTERNAL "nvrtc wrapper library path") + +# The same Executor Static library +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name") +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path") diff --git a/backends/trtllm/cmake/utils/detect_cuda_arch.cu b/backends/trtllm/cmake/utils/detect_cuda_arch.cu new file mode 100644 index 000000000..e69de29bb diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h new file mode 100644 index 000000000..7990e76b9 --- /dev/null +++ b/backends/trtllm/include/backend.h @@ -0,0 +1,121 @@ +// +// Created by Morgan Funtowicz on 6/30/24. +// + +#ifndef TGI_TRTLLM_BACKEND_H +#define TGI_TRTLLM_BACKEND_H + +#include +#include +#include +#include + +#include + +#include +#include +#include + +using json = nlohmann::json; +namespace tle = tensorrt_llm::executor; + +namespace huggingface::tgi::backends { + using RequestId = tle::IdType; + using TokenId = tle::TokenIdType; + + /** + * Initialize all the components required by TRTLLM. 
+ * It is required to call this function before attempting to load any engine + */ + void InitializeBackend(); + + /** + * + * @param config TensorRT-LLM configuration object + * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode + * @return + */ + tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath); + + /** + * Get the sampling configuration from the parameters provided by TGI + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + tle::SamplingConfig GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + */ + class TensorRtLlmBackend { + private: + const json config; + tle::Executor executor; + + public: + explicit TensorRtLlmBackend( + const std::filesystem::path &engineFolder, + const std::filesystem::path &executorWorker + ); + + /** + * Indicate if the backend is ready to accept incoming request + * @return true if ready, false otherwise + */ + [[nodiscard]] bool IsReady() const; + + /** + * Query the executor for the number of token available for pulling + * @return + */ + [[nodiscard]] size_t NumResponsesReady() const; + + /** + * Submit a new generation task to the executor + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return Request id related to this generation for reference + */ + [[nodiscard]] RequestId Submit( + const std::vector &tokens, + int32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + * @param requestId The request id to poll the generation results + * @return + */ + std::vector Poll(RequestId requestId); + + /** + * Stop the underlying executor + */ + void Shutdown(); + }; +} + + +#endif //TGI_TRTLLM_BACKEND_H diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h new file mode 100644 index 000000000..fe0be9fc8 --- /dev/null +++ b/backends/trtllm/include/ffi.h @@ -0,0 +1,75 @@ +// +// Created by mfuntowicz on 7/11/24. 
+// + +#ifndef TGI_TRTLLM_BACKEND_FFI_H +#define TGI_TRTLLM_BACKEND_FFI_H + +#include +#include "backend.h" + +namespace huggingface::tgi::backends { + class TensorRtLlmBackendImpl; +} + +#include "backends/trtllm/src/lib.rs.h" + + +namespace huggingface::tgi::backends { + +// struct GenerationContext; + + class TensorRtLlmBackendImpl : public TensorRtLlmBackend { + public: + /*** + * + * @param engineFolder + * @param executorWorker + */ + TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker); + + /*** + * + * @return + */ + bool IsReady() const; + + /*** + * + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] + uint64_t + Submit(rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); + + /*** + * + * @param requestId + * @param ctx + * @param callback + * @return + */ + size_t StreamTokens( + const RequestId requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback); + }; + + /*** + * + * @param engineFolder + * @return + */ + std::unique_ptr CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker); +} + +#endif //TGI_TRTLLM_BACKEND_FFI_H diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h new file mode 100644 index 000000000..da0bf4f3c --- /dev/null +++ b/backends/trtllm/include/hardware.h @@ -0,0 +1,59 @@ +// +// Created by mfuntowicz on 7/23/24. +// + +#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H +#define TGI_TRTLLM_BACKEND_HARDWARE_H + +#include +#include +#include +#include +#include + +namespace huggingface::hardware::cuda { + +#define AMPERE_SM_MAJOR 8 +#define HOPPER_SM_MAJOR 8 + + /** + * Store information about the version of the CUDA Compute Capabilities detected on the device + */ + struct CudaComputeCapabilities { + int32_t major; + int32_t minor; + + [[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; } + + [[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; } + }; + + CudaComputeCapabilities GetCudaComputeCapabilities() { + // Get the compute capabilities of the current hardware + nvmlDevice_t device; + CudaComputeCapabilities capabilities{0, 0}; + if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) { + SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0"); + if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) { + SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor); + } + } + + return capabilities; + } + + /** + * Return the number of GPU detected. 
If no GPU is detected, return size_t::max() + * @return + */ + std::optional GetNumDevices() { + uint32_t numGpus = 0; + if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) { + return std::optional(numGpus); + } else { + return std::nullopt; + } + } +} + +#endif //TGI_TRTLLM_BACKEND_HARDWARE_H diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp new file mode 100644 index 000000000..c066a6d6e --- /dev/null +++ b/backends/trtllm/lib/backend.cpp @@ -0,0 +1,146 @@ +#include + +#include +#include +#include + +#include "backend.h" +#include "hardware.h" + +void huggingface::tgi::backends::InitializeBackend() { + SPDLOG_INFO("Initializing Backend..."); + nvmlInit_v2(); + initTrtLlmPlugins(); + + const auto numGpus = huggingface::hardware::cuda::GetNumDevices(); + if (numGpus.has_value()) { + SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value()); + } else { + SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system"); + } +} + +[[nodiscard]] +tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) { + tle::ExecutorConfig execConfig(1); + + // Retrieve the compute capabilities to enable some options at runtime + const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities(); + + // Single engine (TP = PP = 1) -> using leader mode (no MPI involved) + if (config["/pretrained_config/mapping/world_size"_json_pointer].get() == 1) { + SPDLOG_INFO("Detected single engine deployment, using leader mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kLEADER, + std::nullopt, + std::nullopt, + std::nullopt + )); + } else { // Multiple engines -> using orchestrator mode (MPI involved) + SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kORCHESTRATOR, + std::nullopt, + std::nullopt, + tle::OrchestratorConfig(true, workerPath, nullptr, true) + )); + } + + // Define some configuration variables + execConfig.setKvCacheConfig(tle::KvCacheConfig(true)); + execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere()); + return execConfig; +} + +tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed) { + return tle::SamplingConfig( + 1, // TGI only use a single beam + topK, + topP, + std::nullopt, + std::nullopt, + std::nullopt, + seed, + temperature, + temperature, + std::nullopt, + repetition_penalty, + std::nullopt, + frequency_penalty + ); +} + +huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend( + const std::filesystem::path &enginesFolder, + const std::filesystem::path &executorWorker +) : + config(json::parse(std::ifstream(enginesFolder / "config.json"))), + executor( + enginesFolder, + tensorrt_llm::executor::ModelType::kDECODER_ONLY, + GetExecutorConfig(config, executorWorker.string() + )) { + SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref()); +} + +bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const { + return executor.canEnqueueRequests(); +} + +[[nodiscard("Returned number of requests needs to be consumed")]] +size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const { + return executor.getNumResponsesReady(); +} + +[[nodiscard("Returned 
request id needs to be provided back to gather generated tokens")]] +tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( + const std::vector &tokens, + const int32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed +) { +#ifdef NDEBUG + SPDLOG_DEBUG( + FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"), + tokens.size(), + executor.getLatestIterationStats().back().numActiveRequests + ); +#else + SPDLOG_DEBUG( + FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"), + fmt::join(tokens, ", "), + executor.getLatestIterationStats().front().numActiveRequests + ); +#endif + + const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get(); + const auto maxNewTokens = static_cast(std::max(1ul, maxNumTokens - tokens.size())); + + const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed); + const auto output = tle::OutputConfig(true, false, false, true, false); + return executor.enqueueRequest( + tle::Request{tokens, maxNewTokens, true, sampling, output}); +} + +[[nodiscard("Generated tokens result must be used")]] +std::vector huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) { + SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId); + return executor.awaitResponses(requestId); +} + + +void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() { + SPDLOG_INFO("Shutting down executor"); + executor.shutdown(); +} diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh new file mode 100755 index 000000000..e0e2dd17b --- /dev/null +++ b/backends/trtllm/scripts/install_tensorrt.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +set -ex + +TRT_VER="10.2.0.19" +CUDA_VER="12.5" +CUDNN_VER="9.2.1.18-1" +NCCL_VER="2.22.3-1+cuda12.5" +CUBLAS_VER="12.5.3.2-1" +NVRTC_VER="12.5.82-1" + +for i in "$@"; do + case $i in + --TRT_VER=?*) TRT_VER="${i#*=}";; + --CUDA_VER=?*) CUDA_VER="${i#*=}";; + --CUDNN_VER=?*) CUDNN_VER="${i#*=}";; + --NCCL_VER=?*) NCCL_VER="${i#*=}";; + --CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";; + *) ;; + esac + shift +done + +NVCC_VERSION_OUTPUT=$(nvcc --version) +if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then + echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}." 
+ exit 1 +fi + +install_ubuntu_requirements() { + apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates + ARCH=$(uname -m) + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi + curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb + dpkg -i cuda-keyring_1.0-1_all.deb + + apt-get update + if [[ $(apt list --installed | grep libcudnn9) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcudnn9* + fi + if [[ $(apt list --installed | grep libnccl) ]]; then + apt-get remove --purge -y --allow-change-held-packages libnccl* + fi + if [[ $(apt list --installed | grep libcublas) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcublas* + fi + if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then + apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev* + fi + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER} + apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER} + apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} + # NVRTC static library doesn't exist in NGC PyTorch container. + NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER} + apt-get clean + rm -rf /var/lib/apt/lists/* +} + +install_centos_requirements() { + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + yum -y update + yum -y install epel-release + yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER} + yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} + yum clean all +} + +install_tensorrt() { + #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') + #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") + TRT_CUDA_VERSION="12.5" + + if [ -z "$RELEASE_URL_TRT" ];then + ARCH=${TRT_TARGETARCH} + if [ -z "$ARCH" ];then ARCH=$(uname -m);fi + if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi + if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi + RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz + fi + wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar + tar -xf /tmp/TensorRT.tar -C /usr/local/ + mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt + # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl + rm -rf /tmp/TensorRT.tar +} + +# Install base packages depending on the base OS +ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') +case "$ID" in + debian) + install_ubuntu_requirements + install_tensorrt + ;; + ubuntu) + install_ubuntu_requirements + install_tensorrt + ;; + centos) + install_centos_requirements + install_tensorrt + ;; + *) + echo "Unable to determine OS..." 
+ exit 1 + ;; +esac diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs new file mode 100644 index 000000000..b26d06a6d --- /dev/null +++ b/backends/trtllm/src/backend.rs @@ -0,0 +1,329 @@ +use std::future::Future; +use std::path::Path; +use std::pin::{pin, Pin}; +use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, OnceLock}; +use std::task::{Context, Poll}; +use std::time::Duration; + +use async_trait::async_trait; +use cxx::UniquePtr; +use log::{error, warn}; +use tokenizers::Tokenizer; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::sync::RwLock; +use tokio::time::{sleep, Instant}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::{Stream, StreamExt}; +use tracing::{instrument, span, Level}; + +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidationError::UnsupportedModality; +use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError}; +use text_generation_router::{FinishReason, Token}; + +use crate::errors::TensorRtLlmBackendError; +use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl}; + +// Value used to poll the state of the generation stream +static POLLING_INTERVAL_US: OnceLock = OnceLock::new(); + +type InferResult = Result; + +pub(crate) struct Generation { + executor: Arc>>, + done: Arc, +} + +/// Holds the user provided input to be executed along with a channel allowing +/// to bubble up all the generated tokens for that tokens the to end stream. +pub struct GenerationContext { + sender: UnboundedSender>, + tokenizer: Arc, + tokens: Vec, + done: Arc, + queued: Instant, + start: Option, +} + +impl Stream for Generation { + type Item = usize; + + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + let interval = POLLING_INTERVAL_US.get_or_init(|| { + u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100")) + .expect("Invalid value provided for envvar POLLING_INTERVAL_US") + }); + + if !self.done.load(Ordering::Relaxed) { + let backend = pin!(self.executor.read()); + let status = match backend.poll(ctx) { + Poll::Ready(executor_r) => { + let ready = executor_r.num_responses_ready(); + if ready == 0 { + Poll::Pending + } else { + Poll::Ready(Some(ready)) + } + } + Poll::Pending => Poll::Pending, + }; + + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_micros(*interval)).await; + waker.wake(); + }); + + status + } else { + Poll::Ready(None) // end of stream + } + } + + fn size_hint(&self) -> (usize, Option) { + (1, None) + } +} + +unsafe impl Send for TensorRtLlmBackendImpl {} +unsafe impl Sync for TensorRtLlmBackendImpl {} + +/// Implements the logic to execute generation with TensorRT-LLM executor API in background +pub struct TensorRtLlmBackend { + tokenizer: Arc, + + // Backing the backend behind a RwLock to allow concurrent read access to retrieve + // the number of available tokens (read only) in the Generation stream + backend: Arc>>, +} + +impl TensorRtLlmBackend { + pub fn new + Send + 'static, PP: AsRef + Send + 'static>( + tokenizer: Tokenizer, + engine_folder: P, + executor_worker_path: PP, + ) -> Result { + Ok(TensorRtLlmBackend { + tokenizer: Arc::new(tokenizer), + backend: Arc::new(RwLock::new(create_tensorrt_llm_backend( + engine_folder.as_ref().to_str().unwrap(), + executor_worker_path.as_ref().to_str().unwrap(), + ))), + }) + } + 
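+    // Reject inputs this backend cannot serve: top-n tokens, grammar-constrained decoding, empty or multi-chunk inputs, and image chunks.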
+ fn validate(request: &ValidGenerateRequest) -> InferResult<&String> { + if request.top_n_tokens > 1 { + return Err(InferError::ValidationError( + ValidationError::TopNTokensDisabled, + )); + } + + // TODO: Is it really needed? How can it be validated before? + if request.parameters.grammar.is_some() { + return Err(InferError::ValidationError(ValidationError::Grammar)); + } + + match request.inputs.len() { + 0 => Err(InferError::ValidationError(ValidationError::EmptyInput)), + 2.. => Err(InferError::GenerationError( + "TensorRT-LLM backend don't support multi-chunk".into(), + )), + 1 => match request.inputs.first().expect("Single item-chunk") { + Chunk::Text(text) => Ok(text), + Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))), + }, + } + } + + fn generate( + &self, + sender: UnboundedSender>, + tokens: Vec, + top_k: u32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) { + let tokenizer = Arc::clone(&self.tokenizer); + let executor = Arc::clone(&self.backend); + + // Let's push this in async context + tokio::spawn(async move { + // Define the generation state + let mut generation = Generation { + executor: executor.clone(), + done: Arc::new(AtomicBool::new(false)), + }; + + // Define the context over the generation + // TODO(asap): Do we really need so many shared-ownership? + let ctx = Box::new(GenerationContext { + sender: sender.clone(), + tokenizer, + tokens: vec![], + done: Arc::clone(&generation.done), + start: None, + queued: Instant::now(), + }); + + // We are leaking the context on-purpose to avoid the box being dropped while there are + // still computation ongoing + // TODO(asap): Can we achieve the same with an Arc> without the need to go unsafe? + let ctx_ = Box::leak(ctx); + + // Submit the request to the batcher + let request_id = span!(Level::DEBUG, "submit") + .in_scope(|| async { + let mut handle = executor.write().await; + let request_id = handle.pin_mut().submit( + &tokens, + top_k as i32, + top_p, + temperature, + repetition_penalty, + frequency_penalty, + seed, + ); + + request_id + }) + .await; + + while let Some(_) = generation.next().await { + let mut executor_w = executor.write().await; + let executor = executor_w.pin_mut(); + + span!(Level::DEBUG, "decode") + .in_scope(|| async { + unsafe { + executor.stream_tokens( + request_id, + ctx_, + |ctx: *mut GenerationContext, step: GenerationStep| { + let inner_ctx = &mut *ctx; + + // Update the timestamp at which the request started effectively + // Can be a bit off, would need to be before the callback, let's see + inner_ctx.start.get_or_insert(Instant::now()); + inner_ctx.done.store(step.is_final, Ordering::Relaxed); + + // Ensure we are not running into errors + let parcel = if !step.has_error { + // Insert the latest generated token to the tracker + inner_ctx.tokens.push(step.token_id); + + // Decode the token + let text = inner_ctx + .tokenizer + .decode(&[step.token_id], true) + .expect("Failed to decode token"); + + let special = inner_ctx + .tokenizer + .get_added_vocabulary() + .is_special_token(&text); + + // Create the structure holding the token + let token = Token { + id: step.token_id, + text, + logprob: step.log_prob, + special, + }; + + if step.is_final { + let generated_text = inner_ctx + .tokenizer + .decode(&inner_ctx.tokens, true) + .expect("Failed to decode generated_tokens"); + + Ok(InferStreamResponse::End { + token, + top_tokens: vec![], + generated_text: GeneratedText { + text: generated_text, + 
generated_tokens: inner_ctx.tokens.len() as u32, + finish_reason: FinishReason::EndOfSequenceToken, + seed: None, + }, + start: inner_ctx.start.unwrap_or(Instant::now()), + queued: inner_ctx.queued, + }) + } else { + Ok(InferStreamResponse::Intermediate { + token, + top_tokens: vec![], + }) + } + } else { + error!("Error caught while decoding: {}", &step.error_msg); + Err(InferError::GenerationError(step.error_msg)) + }; + + // Send the parcel to the client + inner_ctx + .sender + .send(parcel) + .expect("Failed to sent msg through the channel"); + }, + ); + } + }) + .await; + } + + // "Properly" free the shared context... + // TODO: clean that piece of sh** asap + unsafe { + let _ = Box::from_raw(ctx_); + } + }); + } +} + +#[async_trait] +impl Backend for TensorRtLlmBackend { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> InferResult>> { + // Let's add a few more validation + let input = TensorRtLlmBackend::validate(&request)?; + + // Channel to stream the generated token as they come from the worker thread back to the transport layer + let (sender, receiver) = unbounded_channel(); + + // Unpack parameters + let params = &request.parameters; + + // Preprocess the inputs to send to TRTLLM backend + let encoding = self + .tokenizer + .encode(input.as_str(), true) + .map_err(|e| InferError::GenerationError(e.to_string()))?; + + // Generate the response + self.generate( + sender, + Vec::from(encoding.get_ids()), + params.top_k, + params.top_p, + params.temperature, + params.repetition_penalty, + params.frequency_penalty, + params.seed, + ); + + Ok(UnboundedReceiverStream::new(receiver)) + } + + async fn health(&self, _current_health: bool) -> bool { + true + } +} diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs new file mode 100644 index 000000000..a672d2a40 --- /dev/null +++ b/backends/trtllm/src/errors.rs @@ -0,0 +1,15 @@ +use thiserror::Error; + +use text_generation_router::server; + +#[derive(Debug, Error)] +pub enum TensorRtLlmBackendError { + #[error("Tokenizer error: {0}")] + Tokenizer(String), + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp new file mode 100644 index 000000000..d6317a68c --- /dev/null +++ b/backends/trtllm/src/ffi.cpp @@ -0,0 +1,84 @@ +// +// Created by mfuntowicz on 6/30/24. 
+// +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include "backends/trtllm/include/ffi.h" + + +huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl( + const std::string_view &engineFolder, + const std::string_view &executorWorker +) : TensorRtLlmBackend(engineFolder, executorWorker) {} + + +bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const { + return TensorRtLlmBackend::IsReady(); +} + +uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( + rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed) { + + // This will copy all the items from the initial slice + std::vector tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end())); + return TensorRtLlmBackend::Submit( + std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed); +} + +size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens( + const uint64_t requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback) { + + size_t numTokens = 0; + for (const auto &item: Poll(requestId)) { + GenerationStep step; + if (!item.hasError()) { + SPDLOG_DEBUG("\tStreamTokens -> Decoding token..."); + const auto decoded = item.getResult(); + + const auto token = decoded.outputTokenIds[0][0]; + const auto isFinal = decoded.isFinal; + const auto logProb = decoded.logProbs.value()[0][0]; + + ++numTokens; + + SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal); + step = huggingface::tgi::backends::GenerationStep{ + static_cast(token), logProb, isFinal, false, std::move(std::string()) + }; + SPDLOG_DEBUG("\tStreamTokens -> Post callback"); + } else { + // TODO : Return rest::Result with error + const auto what = item.getErrorMsg(); + SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what); + step = huggingface::tgi::backends::GenerationStep{ + std::numeric_limits::max(), 0.0, true, true, std::move(what) + }; + } + + callback(std::move(ctx), std::move(step)); + } + + return numTokens; +} + +std::unique_ptr +huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) { + // Unconditionally call this to initialize and discover TRTLLM plugins + InitializeBackend(); + + const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end()); + const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end()); + return std::make_unique(std::move(enginePath), std::move(executorPath)); +} diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs new file mode 100644 index 000000000..1a804f889 --- /dev/null +++ b/backends/trtllm/src/lib.rs @@ -0,0 +1,78 @@ +pub use backend::{GenerationContext, TensorRtLlmBackend}; + +mod backend; +pub mod errors; + +#[cxx::bridge(namespace = "huggingface::tgi::backends")] +mod ffi { + + /// Struct used as shared type between rust and C++ to represent the result + /// of a single decoding iteration + pub struct GenerationStep { + token_id: u32, + log_prob: f32, + is_final: bool, + has_error: bool, + error_msg: String, + } + + extern "Rust" { + type GenerationContext; + } + + unsafe extern "C++" { + include!("backends/trtllm/src/ffi.cpp"); + + /// Represent an instance of the underlying TensorRT-LLM backend + type TensorRtLlmBackendImpl; + + /// Create an instance backed behind a std::unique_ptr to manage the 
lifespan of the backend + /// + /// # Arguments + /// + /// * `engine_folder`: Path to the folder containing all the TRTLLM engines + /// * `executor_worker`: Path to the TRTLLM executor worker + /// + /// returns: + /// + /// # Examples + /// + /// ``` + /// + /// ``` + #[rust_name = "create_tensorrt_llm_backend"] + fn CreateTensorRtLlmBackend( + engine_folder: &str, + executor_worker: &str, + ) -> UniquePtr; + + // #[rust_name = "is_ready"] + // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool; + + #[rust_name = "num_responses_ready"] + fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; + + #[rust_name = "submit"] + fn Submit( + self: Pin<&mut TensorRtLlmBackendImpl>, + tokens: &[u32], + top_k: i32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) -> u64; + + #[rust_name = "stream_tokens"] + unsafe fn StreamTokens( + self: Pin<&mut TensorRtLlmBackendImpl>, + request_id: u64, + ctx: *mut GenerationContext, + cb: unsafe fn(*mut GenerationContext, GenerationStep), + ) -> usize; + + // #[rust_name = "shutdown"] + // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); + } +} diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs new file mode 100644 index 000000000..6d6ee1468 --- /dev/null +++ b/backends/trtllm/src/main.rs @@ -0,0 +1,166 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use clap::Parser; +use tokenizers::{FromPretrainedParameters, Tokenizer}; + +use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; +use text_generation_backends_trtllm::TensorRtLlmBackend; +use text_generation_router::server; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(long, env, required = true)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env)] + model_id: String, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(long, env)] + auth_token: Option, + #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")] + executor_worker: PathBuf, +} + +#[tokio::main] +async fn main() -> Result<(), TensorRtLlmBackendError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + 
max_batch_prefill_tokens, + max_batch_total_tokens, + hostname, + port, + tokenizer_name, + tokenizer_config_path, + revision, + model_id, + validation_workers, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + messages_api_enabled, + max_client_batch_size, + auth_token, + executor_worker, + } = args; + + // Launch Tokio runtime + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if !executor_worker.exists() { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!( + "`executor_work` specified path doesn't exists: {}", + executor_worker.display() + ))); + } + + // Run server + let tokenizer = Tokenizer::from_pretrained( + tokenizer_name.clone(), + Some(FromPretrainedParameters { + revision: revision.clone().unwrap_or(String::from("main")), + user_agent: HashMap::new(), + auth_token, + }), + ) + .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?; + + let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?; + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + None, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + false, + None, + None, + messages_api_enabled, + true, + max_client_batch_size, + ) + .await?; + Ok(()) +} diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp new file mode 100644 index 000000000..8520065a7 --- /dev/null +++ b/backends/trtllm/tests/infer_test.cpp @@ -0,0 +1,14 @@ +// +// Created by mfuntowicz on 7/2/24. 
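The argument checks in `main` above all enforce the same family of token-budget invariants. As a rough sketch (the function name and plain `String` error type are illustrative, not part of the crate's API), they boil down to:

```rust
// Token-budget invariants enforced before the server starts, extracted into a
// standalone helper for illustration only.
fn validate_budgets(
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
) -> Result<(), String> {
    if max_input_tokens >= max_total_tokens {
        return Err("`max_input_tokens` must be < `max_total_tokens`".into());
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err("`max_batch_prefill_tokens` must be >= `max_input_tokens`".into());
    }
    if let Some(total) = max_batch_total_tokens {
        if max_batch_prefill_tokens > total {
            return Err("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`".into());
        }
        if max_total_tokens as u32 > total {
            return Err("`max_total_tokens` must be <= `max_batch_total_tokens`".into());
        }
    }
    Ok(())
}

fn main() {
    // Defaults from the CLI above satisfy the invariants.
    assert!(validate_budgets(1024, 2048, 4096, None).is_ok());
    // Input budget must stay strictly below the total budget.
    assert!(validate_budgets(2048, 2048, 4096, None).is_err());
    // Prefill budget may not exceed an explicit batch total budget.
    assert!(validate_budgets(1024, 2048, 4096, Some(2000)).is_err());
}
```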
+// +#include +#include +#include "../include/backend.h" + +TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") { + const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/"); + const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker"); + + spdlog::info("Loading config from: {}", absolute(engines).string()); + huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor); +} diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml new file mode 100644 index 000000000..5d9a140b0 --- /dev/null +++ b/backends/v3/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "text-generation-router-v3" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +thiserror = "1.0.48" +tokenizers = { workspace = true} +tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } +init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +minijinja = { version = "2.0.2" } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] diff --git a/backends/v3/build.rs b/backends/v3/build.rs new file mode 100644 index 000000000..6d702d144 --- /dev/null +++ b/backends/v3/build.rs @@ -0,0 +1,19 @@ +use std::fs; + +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/client/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/client/pb") + .include_file("mod.rs") + 
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs new file mode 100644 index 000000000..d82355dea --- /dev/null +++ b/backends/v3/src/backend.rs @@ -0,0 +1,508 @@ +use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; +/// Batching and inference logic +use crate::queue::{Entry, Queue}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{FinishReason, PrefillToken, Token}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub struct BackendV3 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, + /// Client clone, used for health checks to skip the queue + client: ShardedClient, +} + +impl BackendV3 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + requires_padding: bool, + window_size: Option, + speculate: u32, + ) -> Self { + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; + + let queue = Queue::new( + requires_padding, + block_size, + window_size, + speculate, + max_batch_total_tokens, + ); + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client.clone(), + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + batching_task_notifier.clone(), + )); + + Self { + queue, + batching_task_notifier, + client, + } + } +} + +#[async_trait] +impl Backend for BackendV3 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, InferError> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the queue + self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + block_allocation: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + 
mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + notifier: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut 
entries) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut 
ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + err + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. 
This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { + let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v3_finish_reason { + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} diff --git a/router/src/infer/v3/block_allocator.rs b/backends/v3/src/block_allocator.rs similarity index 100% rename from router/src/infer/v3/block_allocator.rs rename to backends/v3/src/block_allocator.rs diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs new file mode 100644 index 000000000..c407687b7 --- /dev/null +++ b/backends/v3/src/client/grpc_client.rs @@ -0,0 +1,284 @@ +/// Single shard Client +use crate::client::{pb, Chunk}; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v3::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: 
Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + input_chunks: Some(Input { + chunks: input_chunks, + }), + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + adapter_id: None, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: max_input_length, + max_blocks: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + 
&mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs new file mode 100644 index 000000000..755431f46 --- /dev/null +++ b/backends/v3/src/client/mod.rs @@ -0,0 +1,76 @@ +//! Text Generation gRPC client library + +use async_trait::async_trait; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod grpc_client; +mod sharded_client; + +pub use grpc_client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. 
+ /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +// Small convenience re-wrapping of `Chunk`. +impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs new file mode 100644 index 000000000..afb13cdc3 --- /dev/null +++ b/backends/v3/src/client/sharded_client.rs @@ -0,0 +1,260 @@ +use crate::client::{ClientError, Result}; +/// Multi shard Client +use crate::client::{Health, ShardInfo}; + +use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use crate::client::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. 
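`ShardedClient` below applies the same fan-out pattern to every RPC: issue the call on each per-shard client, `join_all` the futures, then collect the per-shard `Result`s into a single `Result`, so any failing shard fails the whole call. A stripped-down sketch of that pattern, with a stub `shard_call` standing in for the gRPC call (assumes the `tokio` and `futures` crates already used by this workspace):

```rust
// Fan out the same call to every shard, await them all, and fold the
// Vec<Result<T, E>> into one Result<Vec<T>, E>.
use futures::future::join_all;

async fn shard_call(shard: usize) -> Result<usize, String> {
    // Stand-in for a per-shard gRPC call such as `client.info()` or `client.prefill(..)`.
    Ok(shard * 2)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let futures: Vec<_> = (0..4).map(shard_call).collect();
    // Any Err in the per-shard results short-circuits into an overall Err.
    let results: Result<Vec<usize>, String> = join_all(futures).await.into_iter().collect();
    println!("{:?}", results?);
    Ok(())
}
```

For calls where every shard returns the same payload (e.g. `info`, `filter_batch`), the client instead keeps only one of the results with `pop()`, as the methods below show.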
+ async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap().map(ShardInfo::from) + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the 
slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), + truncate: 10, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: (0..16).collect(), + adapter_id: None, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + max_blocks: 1, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs new file mode 100644 index 000000000..a6f891692 --- /dev/null +++ b/backends/v3/src/lib.rs @@ -0,0 +1,142 @@ +mod backend; +mod block_allocator; +mod client; +mod queue; + +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV3; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct BackendInfo { + /// Mandatory + #[schema(example = "cuda")] + pub model_device_type: String, + #[schema(example = "torch.float16")] + pub model_dtype: String, + + /// Backend parameters + #[schema(example = "1")] + pub speculate: usize, + #[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub 
max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub max_batch_size: Option, +} + +#[allow(clippy::too_many_arguments)] +pub async fn connect_backend( + max_input_tokens: usize, + max_total_tokens: usize, + master_shard_uds_path: String, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(BackendV3, BackendInfo), V3Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V3Error::Connection)?; + + // server is running on v3 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V3Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V3Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?, + )?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + + let backend_info = BackendInfo { + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + model_device_type: shard_info.device_type.clone(), + model_dtype: shard_info.dtype.clone(), + speculate: shard_info.speculate as usize, + }; + + let backend = BackendV3::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + ); + + tracing::info!("Using backend V3"); + + Ok((backend, backend_info)) +} + +#[derive(Debug, Error)] +pub enum V3Error { + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs new file mode 100644 index 000000000..21952e66e --- /dev/null +++ b/backends/v3/src/main.rs @@ 
-0,0 +1,204 @@ +use clap::{Parser, Subcommand}; +use text_generation_router::{server, usage_stats}; +use text_generation_router_v3::{connect_backend, V3Error}; +use thiserror::Error; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[command(subcommand)] + command: Option, + + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, +} + +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + command, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + validation_workers, + api_key, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats, + } = args; + + if let Some(Commands::PrintSchema) = command { + use utoipa::OpenApi; + let api_doc = text_generation_router::server::ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + }; + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens 
{ + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + let (backend, _backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats, + ) + .await?; + Ok(()) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] V3Error), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/router/src/infer/v3/queue.rs b/backends/v3/src/queue.rs similarity index 95% rename from router/src/infer/v3/queue.rs rename to backends/v3/src/queue.rs index 894d9cab4..9427bd60c 100644 --- a/router/src/infer/v3/queue.rs +++ b/backends/v3/src/queue.rs @@ -1,17 +1,17 @@ -use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; -use crate::infer::InferError; -use crate::infer::InferStreamResponse; -use crate::validation::{ - ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +use crate::block_allocator::{BlockAllocation, BlockAllocator}; +use crate::client; +use crate::client::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::{max, min}; use std::collections::VecDeque; -use text_generation_client::v3::{ - Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, + ValidStoppingParameters, }; -use text_generation_client::ChunksToString; -use text_generation_client::Input; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; @@ -337,8 +337,22 @@ impl State { 
batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - input_chunks: Some(Input { - chunks: entry.request.inputs.clone(), + input_chunks: Some(client::Input { + chunks: entry + .request + .inputs + .clone() + .into_iter() + .map(|c| client::InputChunk { + chunk: Some(match c { + Chunk::Text(text) => client::Chunk::Text(text), + Chunk::Image(image) => client::Chunk::Image(client::Image { + data: image.data, + mimetype: image.mimetype, + }), + }), + }) + .collect(), }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 756460e0a..f82659c9e 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -21,7 +21,7 @@ float-ord = "0.3.2" serde = {version = "1.0.188", features = ["derive"]} serde_json = "1.0" tabled = "0.14.0" -text-generation-client = { path = "../router/client" } +text-generation-client = { path = "../backends/client" } thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } diff --git a/clients/python/README.md b/clients/python/README.md index bf37508e0..88239aa16 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -1,3 +1,6 @@ +# Legacy warning ⚠️ +The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`. + # Text Generation The Hugging Face Text Generation Python library provides a convenient way of interfacing with a diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py index d7a09c9eb..ca783dcdf 100644 --- a/clients/python/text_generation/__init__.py +++ b/clients/python/text_generation/__init__.py @@ -19,5 +19,15 @@ DEPRECATION_WARNING = ( "Please use the `InferenceClient` from the `huggingface_hub` package instead." 
) -from text_generation.client import Client, AsyncClient -from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient +from text_generation.client import Client, AsyncClient # noqa E402 +from text_generation.inference_api import ( # noqa E402 + InferenceAPIClient, + InferenceAPIAsyncClient, +) + +__all__ = [ + "Client", + "AsyncClient", + "InferenceAPIClient", + "InferenceAPIAsyncClient", +] diff --git a/clients/python/text_generation/inference_api.py b/clients/python/text_generation/inference_api.py index 93b0de8d4..b3b98ed28 100644 --- a/clients/python/text_generation/inference_api.py +++ b/clients/python/text_generation/inference_api.py @@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]: List[DeployedModel]: list of all currently deployed models """ resp = requests.get( - f"https://api-inference.huggingface.co/framework/text-generation-inference", + "https://api-inference.huggingface.co/framework/text-generation-inference", headers=headers, timeout=5, ) diff --git a/docs/openapi.json b/docs/openapi.json index 3e7050abb..ed9b0b961 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "2.1.2-dev0" + "version": "2.2.1-dev0" }, "paths": { "/": { @@ -909,7 +909,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolType" + "$ref": "#/components/schemas/ToolChoice" } ], "nullable": true @@ -1580,16 +1580,11 @@ "type": "object", "required": [ "model_id", - "model_dtype", - "model_device_type", "max_concurrent_requests", "max_best_of", "max_stop_sequences", "max_input_tokens", "max_total_tokens", - "waiting_served_ratio", - "max_batch_total_tokens", - "max_waiting_tokens", "validation_workers", "max_client_batch_size", "router", @@ -1601,18 +1596,6 @@ "example": "null", "nullable": true }, - "max_batch_size": { - "type": "integer", - "example": "null", - "nullable": true, - "minimum": 0 - }, - "max_batch_total_tokens": { - "type": "integer", - "format": "int32", - "example": "32000", - "minimum": 0 - }, "max_best_of": { "type": "integer", "example": "2", @@ -1644,19 +1627,6 @@ "example": "2048", "minimum": 0 }, - "max_waiting_tokens": { - "type": "integer", - "example": "20", - "minimum": 0 - }, - "model_device_type": { - "type": "string", - "example": "cuda" - }, - "model_dtype": { - "type": "string", - "example": "torch.float16" - }, "model_id": { "type": "string", "description": "Model info", @@ -1690,11 +1660,6 @@ "version": { "type": "string", "example": "0.5.0" - }, - "waiting_served_ratio": { - "type": "number", - "format": "float", - "example": "1.2" } } }, @@ -2035,6 +2000,14 @@ } } }, + "ToolChoice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, "ToolType": { "oneOf": [ { @@ -2055,6 +2028,11 @@ "$ref": "#/components/schemas/FunctionName" } } + }, + { + "type": "object", + "default": null, + "nullable": true } ] }, diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 119c5662e..e97c00aa2 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -21,6 +21,8 @@ title: Messages API - local: architecture title: Internal Architecture + - local: usage_statistics + title: Usage Statistics title: Getting started - sections: - local: basic_tutorials/consuming_tgi diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 5e40146f5..01f156489 100644 --- 
a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -349,6 +349,12 @@ Options:
       --cors-allow-origin
           [env: CORS_ALLOW_ORIGIN=]

+```
+## API_KEY
+```shell
+      --api-key
+          [env: API_KEY=]
+
 ```
 ## WATERMARK_GAMMA
 ```shell
@@ -424,6 +430,20 @@ Options:
           [env: LORA_ADAPTERS=]

+```
+## USAGE_STATS
+```shell
+      --usage-stats
+          Control if anonymous usage stats are collected. Options are "on", "off" and "no-stack". Default is "on"
+
+          [env: USAGE_STATS=]
+          [default: on]
+
+          Possible values:
+          - on:       Default option, usage statistics are collected anonymously
+          - off:      Disables all collection of usage statistics
+          - no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event
+
 ```
 ## HELP
 ```shell
diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md
index 33d85732a..931a9e3ad 100644
--- a/docs/source/installation_amd.md
+++ b/docs/source/installation_amd.md
@@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
     --device=/dev/kfd --device=/dev/dri --group-add video \
     --ipc=host --shm-size 256g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.1.1-rocm \
+    ghcr.io/huggingface/text-generation-inference:2.2.0-rocm \
     --model-id $model
 ```
diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md
index f9fda863b..b38434900 100644
--- a/docs/source/installation_intel.md
+++ b/docs/source/installation_intel.md
@@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm --privileged --cap-add=sys_nice \
     --device=/dev/dri \
     --ipc=host --shm-size 1g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:latest-intel \
+    ghcr.io/huggingface/text-generation-inference:2.2.0-intel \
     --model-id $model --cuda-graphs 0
 ```
diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md
index 4de6cb19b..dac37d79d 100644
--- a/docs/source/installation_nvidia.md
+++ b/docs/source/installation_nvidia.md
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

 docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.1.1 \
+    ghcr.io/huggingface/text-generation-inference:2.2.0 \
     --model-id $model
 ```
diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md
index f056baad8..2313c69b7 100644
--- a/docs/source/quicktour.md
+++ b/docs/source/quicktour.md
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.1.1 \
+    ghcr.io/huggingface/text-generation-inference:2.2.0 \
     --model-id $model
 ```
@@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \

 To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
 ```bash
-docker run ghcr.io/huggingface/text-generation-inference:2.1.1 --help
+docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help
 ```
diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md
index 2bdd00de6..bc124f319 100644
--- a/docs/source/supported_models.md
+++ b/docs/source/supported_models.md
@@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 ## Supported Models
+- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
diff --git a/docs/source/usage_statistics.md b/docs/source/usage_statistics.md
new file mode 100644
index 000000000..a2c406ecc
--- /dev/null
+++ b/docs/source/usage_statistics.md
@@ -0,0 +1,75 @@
+
+# Collection of Usage Statistics
+
+Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted.
+
+Data is sent twice, once on server startup and once when the server stops. Also, usage statistics are only enabled when TGI is running in Docker, to avoid collecting data when TGI runs directly on the host machine.
+
+## What data is collected
+
+The code that collects the data is available [here](https://github.com/huggingface/text-generation-inference/blob/main/router/src/usage_stats.rs).
+As of release 2.1.2, this is an example of the data collected:
+
+- From the TGI configuration:
+```json
+{
+  "event_type": "start",
+  "disable_grammar_support": false,
+  "max_batch_prefill_tokens": 4096,
+  "max_batch_size": null,
+  "max_batch_total_tokens": null,
+  "max_best_of": 2,
+  "max_client_batch_size": 4,
+  "max_concurrent_requests": 128,
+  "max_input_tokens": 1024,
+  "max_stop_sequences": 4,
+  "max_top_n_tokens": 5,
+  "max_total_tokens": 2048,
+  "max_waiting_tokens": 20,
+  "messages_api_enabled": false,
+  "model_config": {
+    "model_type": "Bloom"
+  },
+  "revision": null,
+  "tokenizer_class": "BloomTokenizerFast",
+  "validation_workers": 2,
+  "waiting_served_ratio": 1.2,
+  "docker_label": "latest",
+  "git_sha": "cfc118704880453d29bcbe4fbbd91dda501cf5fe",
+  "nvidia_env": {
+    "name": "NVIDIA A10G",
+    "pci_bus_id": "00000000:00:1E.0",
+    "driver_version": "535.183.01",
+    "pstate": "P8",
+    "pcie_link_gen_max": "4",
+    "pcie_link_gen_current": "1",
+    "temperature_gpu": "31",
+    "utilization_gpu": "0 %",
+    "utilization_memory": "0 %",
+    "memory_total": "23028 MiB",
+    "memory_free": "22515 MiB",
+    "memory_used": "0 MiB",
+    "reset_status_reset_required": "No",
+    "reset_status_drain_and_reset_recommended": "No",
+    "compute_cap": "8.6",
+    "ecc_errors_corrected_volatile_total": "0",
+    "mig_mode_current": "[N/A]",
+    "power_draw_instant": "10.86 W",
+    "power_limit": "300.00 W"
+  },
+  "system_env": {
+    "cpu_count": 16,
+    "cpu_type": "AMD EPYC 7R32",
+    "total_memory": 66681196544,
+    "architecture": "x86_64",
+    "platform": "linux-unix-x86_64"
+  }
+}
+
+```
+
+## How to opt-out
+
+By passing the `--usage-stats` flag to `text-generation-launcher`, you can control how much usage-statistics data is collected.
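+
+The available levels are described below. As a sketch, disabling collection entirely when launching via Docker looks like the quick tour command with the flag appended (the model, volume, and image tag here are simply the values used earlier in these docs; substitute your own launch command):
+
+```bash
+model=teknium/OpenHermes-2.5-Mistral-7B
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+# same as the quick tour launch command, with usage statistics disabled
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:2.2.0 \
+    --model-id $model \
+    --usage-stats=off
+```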
+`--usage-stats=no-stack` will not emit the stack traces from errors and the error types, but will continue to send start and stop events +`--usage-stats=off` will completely disable everything diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index f5f38ac6c..46a8769f7 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -4,7 +4,6 @@ import json import math import os import random -import re import shutil import subprocess import sys @@ -271,7 +270,7 @@ class LauncherHandle: try: await self.client.generate("test") return - except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e: + except (ClientConnectorError, ClientOSError, ServerDisconnectedError): time.sleep(1) raise RuntimeError("Health check failed") @@ -333,6 +332,8 @@ def launcher(event_loop): max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, + lora_adapters: Optional[List[str]] = None, + cuda_graphs: Optional[List[int]] = None, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -379,6 +380,14 @@ def launcher(event_loop): if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) + if lora_adapters: + args.append("--lora-adapters") + args.append(",".join(lora_adapters)) + if cuda_graphs: + args.append("--cuda-graphs") + args.append(",".join(map(str, cuda_graphs))) + + print(" ".join(args), file=sys.stderr) env["LOG_LEVEL"] = "info,text_generation_router=debug" @@ -418,6 +427,8 @@ def launcher(event_loop): max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, + lora_adapters: Optional[List[str]] = None, + cuda_graphs: Optional[List[int]] = None, ): port = random.randint(8000, 10_000) @@ -447,6 +458,12 @@ def launcher(event_loop): if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) + if lora_adapters: + args.append("--lora-adapters") + args.append(",".join(lora_adapters)) + if cuda_graphs: + args.append("--cuda-graphs") + args.append(",".join(map(str, cuda_graphs))) client = docker.from_env() diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json index 53a4ab854..b274992ea 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json @@ -11,52 +11,52 @@ }, { "id": 49833, - "logprob": -10.5625, + "logprob": -10.5703125, "text": " dég" }, { "id": 21543, - "logprob": -0.14770508, + "logprob": -0.14746094, "text": "uster" }, { "id": 447, - "logprob": -1.9287109, + "logprob": -1.9277344, "text": " un" }, { "id": 46341, - "logprob": -15.4609375, + "logprob": -15.421875, "text": " ort" }, { "id": 35567, - "logprob": -7.5585938, + "logprob": -7.5820312, "text": "olan" }, { "id": 15, - "logprob": -1.4003906, + "logprob": -1.4013672, "text": "," }, { "id": 1669, - "logprob": -1.5673828, + "logprob": -1.5664062, "text": " il" }, { "id": 11580, - "logprob": -0.94628906, + "logprob": -0.94189453, "text": " faut" }, { "id": 3913, - "logprob": -3.703125, + "logprob": -3.6816406, "text": " tout" }, { "id": 39261, - "logprob": -1.5732422, + "logprob": -1.7753906, "text": " d'abord" } ], @@ -64,65 +64,66 @@ "tokens": [ { "id": 578, - "logprob": -1.6591797, + "logprob": -1.6318359, "special": false, "text": " 
le" }, { "id": 5608, - "logprob": -2.4492188, + "logprob": -2.4882812, "special": false, "text": " faire" }, { - "id": 159570, - "logprob": -6.6835938, + "id": 7735, + "logprob": -2.4355469, "special": false, - "text": " réch" + "text": " fond" }, { - "id": 810, + "id": 289, "logprob": 0.0, "special": false, - "text": "au" + "text": "re" }, { - "id": 12736, + "id": 693, + "logprob": -2.4472656, + "special": false, + "text": " à" + }, + { + "id": 366, + "logprob": -1.1972656, + "special": false, + "text": " la" + }, + { + "id": 48844, + "logprob": -1.7890625, + "special": false, + "text": " cass" + }, + { + "id": 1744, "logprob": 0.0, "special": false, - "text": "ffer" + "text": "ero" }, { - "id": 1742, - "logprob": -2.5175781, - "special": false, - "text": " au" - }, - { - "id": 6105, - "logprob": -2.0078125, - "special": false, - "text": " bain" - }, - { - "id": 88254, - "logprob": -0.12695312, - "special": false, - "text": "-mar" - }, - { - "id": 641, + "id": 327, "logprob": 0.0, "special": false, - "text": "ie" + "text": "le" }, { "id": 2940, - "logprob": -3.5175781, + "logprob": -1.9335938, "special": false, "text": " avec" } - ] + ], + "top_tokens": null }, - "generated_text": " le faire réchauffer au bain-marie avec" + "generated_text": " le faire fondre à la casserole avec" } diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json index ace734160..9422f27ff 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json @@ -11,7 +11,7 @@ }, { "id": 1669, - "logprob": -5.4414062, + "logprob": -5.4453125, "text": " il" }, { @@ -21,12 +21,12 @@ }, { "id": 3913, - "logprob": -4.3554688, + "logprob": -4.3320312, "text": " tout" }, { "id": 39261, - "logprob": -2.9238281, + "logprob": -2.9160156, "text": " d'abord" } ], @@ -34,65 +34,66 @@ "tokens": [ { "id": 408, - "logprob": -0.07891846, + "logprob": -0.16687012, "special": false, "text": " que" }, { "id": 366, - "logprob": -1.2939453, + "logprob": -1.5517578, "special": false, "text": " la" }, { "id": 8769, - "logprob": -0.3708496, + "logprob": -0.16687012, "special": false, "text": " personne" }, { "id": 1479, - "logprob": -2.2871094, + "logprob": -2.1035156, "special": false, "text": " qui" }, { - "id": 2997, - "logprob": -0.8671875, + "id": 143926, + "logprob": -2.8671875, "special": false, - "text": " vous" + "text": " réalise" }, { - "id": 35977, - "logprob": -1.5097656, + "id": 578, + "logprob": 0.0, "special": false, - "text": " suit" + "text": " le" }, { - "id": 21558, - "logprob": -0.07891846, + "id": 8138, + "logprob": -0.66748047, "special": false, - "text": " ait" + "text": " projet" }, { - "id": 447, - "logprob": -0.12695312, + "id": 795, + "logprob": -1.6279297, "special": false, - "text": " un" + "text": " ne" }, { - "id": 78606, - "logprob": -2.21875, + "id": 9802, + "logprob": -0.47875977, "special": false, - "text": " profil" + "text": " soit" }, { - "id": 3899, - "logprob": -1.3535156, + "id": 1230, + "logprob": 0.0, "special": false, - "text": " bien" + "text": " pas" } - ] + ], + "top_tokens": null }, - "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien" + "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui réalise le projet ne soit pas" } diff --git 
a/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json index dd8936afd..b17c889e8 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json @@ -11,52 +11,52 @@ }, { "id": 49833, - "logprob": -10.5390625, + "logprob": -10.546875, "text": " dég" }, { "id": 21543, - "logprob": -0.14758301, + "logprob": -0.14819336, "text": "uster" }, { "id": 447, - "logprob": -1.9296875, + "logprob": -1.9257812, "text": " un" }, { "id": 46341, - "logprob": -15.4453125, + "logprob": -15.4296875, "text": " ort" }, { "id": 35567, - "logprob": -7.59375, + "logprob": -7.5625, "text": "olan" }, { "id": 15, - "logprob": -1.3994141, + "logprob": -1.4199219, "text": "," }, { "id": 1669, - "logprob": -1.578125, + "logprob": -1.5634766, "text": " il" }, { "id": 11580, - "logprob": -0.9453125, + "logprob": -0.9458008, "text": " faut" }, { "id": 3913, - "logprob": -3.7011719, + "logprob": -3.6816406, "text": " tout" }, { "id": 39261, - "logprob": -1.5732422, + "logprob": -1.7753906, "text": " d'abord" } ], @@ -64,65 +64,66 @@ "tokens": [ { "id": 578, - "logprob": -1.6474609, + "logprob": -1.828125, "special": false, "text": " le" }, { "id": 5608, - "logprob": -2.5097656, + "logprob": -2.5546875, "special": false, "text": " faire" }, { - "id": 159570, - "logprob": -6.65625, + "id": 7735, + "logprob": -2.4277344, "special": false, - "text": " réch" + "text": " fond" }, { - "id": 810, + "id": 289, "logprob": 0.0, "special": false, - "text": "au" + "text": "re" }, { - "id": 12736, + "id": 693, + "logprob": -2.4472656, + "special": false, + "text": " à" + }, + { + "id": 366, + "logprob": -1.1494141, + "special": false, + "text": " la" + }, + { + "id": 48844, + "logprob": -1.7939453, + "special": false, + "text": " cass" + }, + { + "id": 1744, "logprob": 0.0, "special": false, - "text": "ffer" + "text": "ero" }, { - "id": 1742, - "logprob": -2.5859375, - "special": false, - "text": " au" - }, - { - "id": 6105, - "logprob": -2.03125, - "special": false, - "text": " bain" - }, - { - "id": 88254, - "logprob": -0.12695312, - "special": false, - "text": "-mar" - }, - { - "id": 641, + "id": 327, "logprob": 0.0, "special": false, - "text": "ie" + "text": "le" }, { "id": 2940, - "logprob": -3.5175781, + "logprob": -1.9013672, "special": false, "text": " avec" } - ] + ], + "top_tokens": null }, - "generated_text": " le faire réchauffer au bain-marie avec" + "generated_text": " le faire fondre à la casserole avec" } diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json index 99c33cf75..9f3faffce 100644 --- a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json @@ -1,11 +1,17 @@ { "choices": [ { - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 1, "logprobs": null, "text": " PR for more information?" 
}, + { + "finish_reason": "length", + "index": 3, + "logprobs": null, + "text": "hd20220811-" + }, { "finish_reason": "length", "index": 0, @@ -17,19 +23,13 @@ "index": 2, "logprobs": null, "text": " severely flawed and often has a substandard" - }, - { - "finish_reason": "length", - "index": 3, - "logprobs": null, - "text": "hd20220811-" } ], - "created": 1713284455, + "created": 1722014725, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native", + "system_fingerprint": "2.2.1-dev0-native", "usage": { "completion_tokens": 36, "prompt_tokens": 8, diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json new file mode 100644 index 000000000..03f903672 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.84375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.34375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.8359375, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.0859375, + "special": false, + "text": " is" + }, + { + "id": 254, + "logprob": -1.5390625, + "special": false, + "text": " the" + }, + { + "id": 1022, + "logprob": -1.1875, + "special": false, + "text": " first" + }, + { + "id": 3458, + "logprob": -0.35546875, + "special": false, + "text": " step" + }, + { + "id": 279, + "logprob": -0.8828125, + "special": false, + "text": " in" + }, + { + "id": 254, + "logprob": -0.71484375, + "special": false, + "text": " the" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is the first step in the" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json new file mode 100644 index 000000000..6b45cf6b9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -0,0 +1,53 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 4, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2143, + "logprob": -1.828125, + "special": false, + "text": " sent" + }, + { + "id": 10081, + "logprob": -0.41210938, + "special": false, + "text": " successfully" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 100001, + "logprob": -0.16015625, + "special": true, + "text": "<|end▁of▁sentence|>" + } + ], + "top_tokens": null + }, + "generated_text": "Test request sent successfully." 
+} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json new file mode 100644 index 000000000..e365829a2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.859375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.359375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.83203125, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.125, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5703125, + "special": false, + "text": " a" + }, + { + "id": 3412, + "logprob": -2.578125, + "special": false, + "text": " document" + }, + { + "id": 344, + "logprob": -1.125, + "special": false, + "text": " that" + }, + { + "id": 317, + "logprob": -1.6953125, + "special": false, + "text": " is" + }, + { + "id": 1222, + "logprob": -1.75, + "special": false, + "text": " used" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a document that is used" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.859375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.359375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.83203125, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.125, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5703125, + "special": false, + "text": " a" + }, + { + "id": 3412, + "logprob": -2.578125, + "special": false, + "text": " document" + }, + { + "id": 344, + "logprob": -1.125, + "special": false, + "text": " that" + }, + { + "id": 317, + "logprob": -1.6953125, + "special": false, + "text": " is" + }, + { + "id": 1222, + "logprob": -1.75, + "special": false, + "text": " used" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a document that is used" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.859375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.359375, + "special": 
false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.83203125, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.125, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5703125, + "special": false, + "text": " a" + }, + { + "id": 3412, + "logprob": -2.578125, + "special": false, + "text": " document" + }, + { + "id": 344, + "logprob": -1.125, + "special": false, + "text": " that" + }, + { + "id": 317, + "logprob": -1.6953125, + "special": false, + "text": " is" + }, + { + "id": 1222, + "logprob": -1.75, + "special": false, + "text": " used" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a document that is used" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.859375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.359375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.83203125, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.125, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5703125, + "special": false, + "text": " a" + }, + { + "id": 3412, + "logprob": -2.578125, + "special": false, + "text": " document" + }, + { + "id": 344, + "logprob": -1.125, + "special": false, + "text": " that" + }, + { + "id": 317, + "logprob": -1.6953125, + "special": false, + "text": " is" + }, + { + "id": 1222, + "logprob": -1.75, + "special": false, + "text": " used" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a document that is used" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json index 80f0d053d..8829f9fe6 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -11,12 +11,12 @@ }, { "id": 2015, - "logprob": -10.0, + "logprob": -10.0625, "text": "Test" }, { "id": 3853, - "logprob": -10.875, + "logprob": -11.0, "text": " request" } ], @@ -24,7 +24,7 @@ "tokens": [ { "id": 1736, - "logprob": -2.09375, + "logprob": -2.03125, "special": false, "text": " form" }, @@ -42,48 +42,48 @@ }, { "id": 2121, - "logprob": -1.8203125, + "logprob": -1.8125, "special": false, "text": " test" }, { "id": 3853, - "logprob": -0.23242188, + "logprob": -0.24121094, "special": false, "text": " request" }, { "id": 1736, - "logprob": -0.08544922, + "logprob": -0.100097656, "special": false, "text": " form" }, { "id": 603, - "logprob": -0.9375, + "logprob": -0.9453125, "special": false, "text": " is" }, { - "id": 1671, - "logprob": -1.671875, + "id": 476, + "logprob": -1.703125, "special": false, - "text": " used" + "text": " a" }, { - "id": 577, - "logprob": -0.40429688, + "id": 4551, + "logprob": -2.453125, "special": false, - "text": " to" + "text": " document" }, { - "id": 3853, - "logprob": -1.1875, + "id": 674, + "logprob": -0.796875, "special": false, - "text": " request" + "text": " that" } ], "top_tokens": null }, - 
"generated_text": " form\n\nThe test request form is used to request" + "generated_text": " form\n\nThe test request form is a document that" } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json index 8253dc965..0b840bfda 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json @@ -11,12 +11,12 @@ }, { "id": 2015, - "logprob": -10.0, + "logprob": -10.0625, "text": "Test" }, { "id": 3853, - "logprob": -10.875, + "logprob": -11.0, "text": " request" } ], @@ -24,7 +24,7 @@ "tokens": [ { "id": 7539, - "logprob": -0.73046875, + "logprob": -0.609375, "special": false, "text": " forms" }, @@ -36,7 +36,7 @@ }, { "id": 671, - "logprob": -1.703125, + "logprob": -1.5546875, "special": false, "text": " an" }, @@ -66,24 +66,24 @@ }, { "id": 11859, - "logprob": -1.6953125, + "logprob": -1.953125, "special": false, "text": " lab" }, { "id": 2185, - "logprob": -1.3125, + "logprob": -1.7734375, "special": false, "text": " process" }, { - "id": 578, - "logprob": -1.5, + "id": 235265, + "logprob": 0.0, "special": false, - "text": " and" + "text": "." } ], "top_tokens": null }, - "generated_text": "Test request forms are an essential part of the lab process and" + "generated_text": "Test request forms are an essential part of the lab process." } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json new file mode 100644 index 000000000..1e9a50cf4 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json @@ -0,0 +1,254 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.15625, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.1171875, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.84375, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.546875, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.65625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.84375, + "text": " " + }, + { + "id": 235274, + "logprob": -0.45117188, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07421875, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.4140625, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009536743, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.033203125, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.0002670288, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.75, + "text": "," + }, + { + "id": 7385, + "logprob": -11.625, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": 
" each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.29882812, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16699219, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.5, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.125, + "text": ":" + }, + { + "id": 108, + "logprob": -3.421875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.099121094, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.078125, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.025756836, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.29101562, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035858154, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.1007996e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json new file mode 100644 index 000000000..5c47dd3cc --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json @@ -0,0 +1,1018 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.25, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.111816406, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.78125, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.59375, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.7421875, + "text": " " + }, + { + "id": 235274, + "logprob": -0.44921875, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07128906, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.40429688, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009918213, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.03540039, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.00028800964, + "text": " table" + }, + { + 
"id": 235269, + "logprob": -4.71875, + "text": "," + }, + { + "id": 7385, + "logprob": -11.875, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.3046875, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16113281, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.25, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.1875, + "text": ":" + }, + { + "id": 108, + "logprob": -3.296875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08886719, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.09375, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035552979, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.220009e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.25, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.111816406, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.78125, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.59375, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.7421875, + "text": " " + }, + { + "id": 235274, + "logprob": -0.44921875, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07128906, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.40429688, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009918213, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.03540039, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.00028800964, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.71875, + "text": "," + }, + { + "id": 7385, + "logprob": -11.875, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + 
"id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.3046875, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16113281, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.25, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.1875, + "text": ":" + }, + { + "id": 108, + "logprob": -3.296875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08886719, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.09375, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035552979, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.220009e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.15625, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.1171875, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.84375, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.546875, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.65625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.84375, + "text": " " + }, + { + "id": 235274, + "logprob": -0.45117188, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07421875, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.4140625, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009536743, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.033203125, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.0002670288, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.75, + "text": "," + }, + { + "id": 7385, + "logprob": -11.625, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + 
{ + "id": 2017, + "logprob": -0.29882812, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16699219, + "text": "." + }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.5, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.125, + "text": ":" + }, + { + "id": 108, + "logprob": -3.421875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08984375, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.1015625, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0038452148, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.1484833e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 106, + "logprob": -47.25, + "text": "" + }, + { + "id": 1645, + "logprob": -18.875, + "text": "user" + }, + { + "id": 235292, + "logprob": -7.25, + "text": ":" + }, + { + "id": 108, + "logprob": -4.78125, + "text": "\n" + }, + { + "id": 5559, + "logprob": -10.0, + "text": "Write" + }, + { + "id": 476, + "logprob": -0.111816406, + "text": " a" + }, + { + "id": 19592, + "logprob": -2.46875, + "text": " poem" + }, + { + "id": 577, + "logprob": -5.78125, + "text": " to" + }, + { + "id": 1707, + "logprob": -6.375, + "text": " help" + }, + { + "id": 682, + "logprob": -2.125, + "text": " me" + }, + { + "id": 5434, + "logprob": -1.59375, + "text": " remember" + }, + { + "id": 573, + "logprob": -0.62890625, + "text": " the" + }, + { + "id": 1370, + "logprob": -6.625, + "text": " first" + }, + { + "id": 235248, + "logprob": -1.7421875, + "text": " " + }, + { + "id": 235274, + "logprob": -0.44921875, + "text": "1" + }, + { + "id": 235276, + "logprob": -0.07128906, + "text": "0" + }, + { + "id": 6635, + "logprob": -2.109375, + "text": " elements" + }, + { + "id": 611, + "logprob": -0.40429688, + "text": " on" + }, + { + "id": 573, + "logprob": -0.0009918213, + "text": " the" + }, + { + "id": 26163, + "logprob": -0.03540039, + "text": " periodic" + }, + { + "id": 3037, + "logprob": -0.00028800964, + "text": " table" + }, + { + "id": 235269, + "logprob": -4.71875, + "text": "," + }, + { + "id": 7385, + "logprob": -11.875, + "text": " giving" + }, + { + "id": 1853, + "logprob": -4.875, + "text": " each" + }, + { + "id": 5356, + "logprob": -0.38867188, + "text": " element" + }, + { + "id": 1277, + "logprob": -3.65625, + "text": " its" + }, + { + "id": 1997, + "logprob": -4.4375, + "text": " own" + }, + { + "id": 2017, + "logprob": -0.3046875, + "text": " line" + }, + { + "id": 235265, + "logprob": -0.16113281, + "text": "." 
+ }, + { + "id": 107, + "logprob": -25.625, + "text": "" + }, + { + "id": 108, + "logprob": -6.75, + "text": "\n" + }, + { + "id": 106, + "logprob": -39.25, + "text": "" + }, + { + "id": 2516, + "logprob": -32.5, + "text": "model" + }, + { + "id": 235292, + "logprob": -10.1875, + "text": ":" + }, + { + "id": 108, + "logprob": -3.296875, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 688, + "logprob": -0.546875, + "special": false, + "text": "**" + }, + { + "id": 103889, + "logprob": -0.49023438, + "special": false, + "text": "Hydrogen" + }, + { + "id": 190213, + "logprob": -0.48632812, + "special": false, + "text": "**," + }, + { + "id": 2611, + "logprob": -0.58203125, + "special": false, + "text": " light" + }, + { + "id": 578, + "logprob": -0.08886719, + "special": false, + "text": " and" + }, + { + "id": 2223, + "logprob": -1.09375, + "special": false, + "text": " free" + }, + { + "id": 235269, + "logprob": -0.024291992, + "special": false, + "text": "," + }, + { + "id": 108, + "logprob": -0.30664062, + "special": false, + "text": "\n" + }, + { + "id": 688, + "logprob": -0.0035552979, + "special": false, + "text": "**" + }, + { + "id": 1949, + "logprob": -4.220009e-05, + "special": false, + "text": "He" + } + ], + "top_tokens": null + }, + "generated_text": "**Hydrogen**, light and free,\n**He" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json index 7a168b2ea..bc80a0f91 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json @@ -11,12 +11,12 @@ }, { "id": 2015, - "logprob": -9.65625, + "logprob": -9.640625, "text": "Test" }, { "id": 3853, - "logprob": -10.3671875, + "logprob": -10.375, "text": " request" } ], @@ -24,66 +24,66 @@ "tokens": [ { "id": 604, - "logprob": -0.36938477, + "logprob": -0.2824707, "special": false, "text": " for" }, { - "id": 235248, - "logprob": -1.8046875, + "id": 573, + "logprob": -0.19030762, "special": false, - "text": " " + "text": " the" }, { - "id": 235274, - "logprob": -0.46240234, + "id": 16819, + "logprob": -1.4892578, "special": false, - "text": "1" + "text": " detection" }, { - "id": 235284, - "logprob": -1.7460938, + "id": 576, + "logprob": -0.7011719, "special": false, - "text": "2" + "text": " of" }, { - "id": 235265, - "logprob": -1.9443359, + "id": 573, + "logprob": -2.0195312, "special": false, - "text": "." 
+ "text": " the" }, { - "id": 235284, - "logprob": -1.4550781, - "special": false, - "text": "2" - }, - { - "id": 235308, - "logprob": -1.0205078, - "special": false, - "text": "5" - }, - { - "id": 235290, - "logprob": -1.0283203, - "special": false, - "text": "-" - }, - { - "id": 235274, - "logprob": -1.2783203, - "special": false, - "text": "1" - }, - { - "id": 235284, + "id": 8566, "logprob": 0.0, "special": false, - "text": "2" + "text": " presence" + }, + { + "id": 689, + "logprob": -0.16491699, + "special": false, + "text": " or" + }, + { + "id": 14862, + "logprob": 0.0, + "special": false, + "text": " absence" + }, + { + "id": 576, + "logprob": -0.9946289, + "special": false, + "text": " of" + }, + { + "id": 671, + "logprob": -0.5263672, + "special": false, + "text": " an" } ], "top_tokens": null }, - "generated_text": "Test request for 12.25-12" + "generated_text": "Test request for the detection of the presence or absence of an" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json new file mode 100644 index 000000000..85cfb91f1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7900391, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3554688, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0039062, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.4489746, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8100586, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013015747, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json new file mode 100644 index 000000000..bf981e4f1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.5625, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.375, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 25, + "logprob": -0.8984375, + "special": false, + "text": ":" + }, + { + "id": 2209, + "logprob": -2.78125, + "special": false, + "text": " Is" + }, + { + "id": 279, + "logprob": -0.6328125, + "special": false, 
+ "text": " the" + }, + { + "id": 734, + "logprob": -2.703125, + "special": false, + "text": " function" + }, + { + "id": 330, + "logprob": -0.34179688, + "special": false, + "text": " \"" + }, + { + "id": 4110, + "logprob": -2.359375, + "special": false, + "text": "Create" + }, + { + "id": 7575, + "logprob": -2.1875, + "special": false, + "text": "Process" + }, + { + "id": 1, + "logprob": -0.07910156, + "special": false, + "text": "\"" + }, + { + "id": 304, + "logprob": -0.83203125, + "special": false, + "text": " in" + }, + { + "id": 12468, + "logprob": -1.8203125, + "special": false, + "text": " Win" + } + ], + "top_tokens": null + }, + "generated_text": "Test request: Is the function \"CreateProcess\" in Win" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json new file mode 100644 index 000000000..36c87c097 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": 
false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json new file mode 100644 index 000000000..94883de5f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": 
false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json new file mode 100644 index 000000000..58cacb802 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5229, + "logprob": -0.6645508, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 6527, + "logprob": -2.2324219, + "special": false, + "text": " Could" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " not" + }, + { + "id": 6088, + "logprob": -1.6074219, + "special": false, + "text": " parse" + }, + { + "id": 1243, + "logprob": -1.6298828, + "special": false, + "text": " test" + }, + { + "id": 1206, + "logprob": -0.72558594, + "special": false, + "text": " case" + }, + { + "id": 1024, + "logprob": -0.40429688, + "special": false, + "text": " name" + }, + { + "id": 515, + "logprob": 0.0, + "special": false, + "text": " from" + }, + { + "id": 525, + "logprob": -1.2519531, + "special": false, + "text": " '" + } + ], + "top_tokens": null + }, + "generated_text": "Test request failed: Could not parse test case name from '" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json new file mode 100644 index 000000000..96a40fa42 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + 
"logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + 
{ + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json index ab4f30156..e3b5575c4 100644 --- a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json +++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json @@ -8,49 +8,49 @@ "tokens": [ { "id": 2502, - "logprob": -1.734375, + "logprob": -1.7890625, "special": false, "text": "image" }, { "id": 2196, - "logprob": -0.5756836, + "logprob": -0.53125, "special": false, "text": " result" }, { "id": 604, - "logprob": -0.007843018, + "logprob": -0.0077209473, "special": false, "text": " for" }, { "id": 12254, - "logprob": -1.7167969, + "logprob": -1.703125, "special": false, "text": " chicken" }, { "id": 611, - "logprob": -0.17053223, + "logprob": -0.21582031, "special": false, "text": " on" }, { "id": 573, - "logprob": -0.7626953, + "logprob": -0.734375, "special": false, "text": " the" }, { "id": 8318, - "logprob": -0.02709961, + "logprob": -0.026000977, "special": false, "text": " beach" }, { "id": 1, - "logprob": -0.20739746, + "logprob": -0.2109375, "special": true, "text": "" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json index 89e02c074..164e3cf28 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -11,17 +11,17 @@ }, { "id": 1459, - "logprob": -5.6328125, + "logprob": -5.625, "text": " print" }, { "id": 81, - "logprob": -1.6035156, + "logprob": -1.6064453, "text": "_" }, { "id": 7656, - "logprob": -5.9882812, + "logprob": -5.9921875, "text": "hello" } ], @@ -29,7 +29,7 @@ "tokens": [ { "id": 2262, - "logprob": -0.042999268, + "logprob": -0.045715332, "special": false, "text": "():" }, @@ -59,7 +59,7 @@ }, { "id": 10896, - "logprob": -0.38549805, + "logprob": -0.3659668, "special": false, "text": " World" }, @@ -113,7 +113,7 @@ }, { "id": 426, - "logprob": 0.0, + "logprob": -0.051635742, "special": false, "text": "name" }, @@ -323,7 +323,7 @@ }, { "id": 313, - "logprob": -0.6328125, + "logprob": 
-0.6933594, "special": false, "text": " \"" }, @@ -387,7 +387,8 @@ "special": false, "text": " print" } - ] + ], + "top_tokens": null }, "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json index 381172723..d882b82ac 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -11,12 +11,12 @@ }, { "id": 1489, - "logprob": -5.2617188, + "logprob": -5.265625, "text": " print" }, { "id": 100, - "logprob": -0.38476562, + "logprob": -0.38549805, "text": "_" }, { @@ -29,7 +29,7 @@ "tokens": [ { "id": 2284, - "logprob": -0.296875, + "logprob": -0.31323242, "special": false, "text": "():" }, @@ -53,19 +53,19 @@ }, { "id": 8302, - "logprob": -0.28125, + "logprob": -0.26611328, "special": false, "text": "Hello" }, { "id": 10914, - "logprob": -0.79248047, + "logprob": -0.7817383, "special": false, "text": " World" }, { "id": 16013, - "logprob": -0.61816406, + "logprob": -0.6328125, "special": false, "text": "!\")" }, @@ -83,7 +83,7 @@ }, { "id": 610, - "logprob": -0.4091797, + "logprob": -0.4086914, "special": false, "text": "def" }, @@ -113,7 +113,7 @@ }, { "id": 444, - "logprob": -0.21655273, + "logprob": -0.21826172, "special": false, "text": "name" }, @@ -160,16 +160,28 @@ "text": "Hello" }, { - "id": 332, - "logprob": -0.034698486, + "id": 925, + "logprob": -3.3476562, "special": false, - "text": " \"" + "text": " %" }, { - "id": 494, + "id": 120, "logprob": 0.0, "special": false, - "text": " +" + "text": "s" + }, + { + "id": 11571, + "logprob": -0.10021973, + "special": false, + "text": "!\"" + }, + { + "id": 925, + "logprob": 0.0, + "special": false, + "text": " %" }, { "id": 655, @@ -178,22 +190,10 @@ "text": " name" }, { - "id": 494, - "logprob": -0.20141602, - "special": false, - "text": " +" - }, - { - "id": 332, + "id": 46, "logprob": 0.0, "special": false, - "text": " \"" - }, - { - "id": 16013, - "logprob": 0.0, - "special": false, - "text": "!\")" + "text": ")" }, { "id": 222, @@ -251,7 +251,7 @@ }, { "id": 400, - "logprob": 0.0, + "logprob": -0.074279785, "special": false, "text": "age" }, @@ -310,34 +310,22 @@ "text": "Hello" }, { - "id": 332, + "id": 925, "logprob": 0.0, "special": false, - "text": " \"" + "text": " %" }, { - "id": 494, + "id": 120, "logprob": 0.0, "special": false, - "text": " +" + "text": "s" }, { - "id": 655, - "logprob": 0.0, + "id": 49, + "logprob": -0.07891846, "special": false, - "text": " name" - }, - { - "id": 494, - "logprob": 0.0, - "special": false, - "text": " +" - }, - { - "id": 3021, - "logprob": -0.5761719, - "special": false, - "text": " \"," + "text": "," }, { "id": 863, @@ -352,43 +340,55 @@ "text": " are" }, { - "id": 332, + "id": 925, "logprob": 0.0, "special": false, - "text": " \"" + "text": " %" }, { - "id": 494, + "id": 105, "logprob": 0.0, "special": false, - "text": " +" + "text": "d" }, { - "id": 615, + "id": 11339, "logprob": 0.0, "special": false, - "text": " str" + "text": " years" }, { - "id": 45, + "id": 3627, "logprob": 0.0, "special": false, - "text": "(" + "text": " old" }, { - "id": 400, + "id": 
11571, "logprob": 0.0, "special": false, - "text": "age" + "text": "!\"" }, { - "id": 46, + "id": 925, "logprob": 0.0, "special": false, - "text": ")" + "text": " %" + }, + { + "id": 327, + "logprob": 0.0, + "special": false, + "text": " (" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" } ], "top_tokens": null }, - "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello %s!\" % name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello %s, you are %d years old!\" % (name" } diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json index 456015059..1fad0b96c 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json @@ -36,13 +36,13 @@ }, { "id": 633, - "logprob": -0.09301758, + "logprob": -0.09161377, "special": false, "text": " new" }, { "id": 4480, - "logprob": -0.3322754, + "logprob": -0.26171875, "special": false, "text": " feature" }, diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json new file mode 100644 index 000000000..dfdd2cc3f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.27416992, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.17016602, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -2.7109375, + "special": false, + "text": "I" + }, + { + "id": 28809, + "logprob": -1.5, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.34204102, + "special": false, + "text": "m" + }, + { + "id": 459, + "logprob": -1.6914062, + "special": false, + "text": " not" + }, + { + "id": 1864, + "logprob": -0.69140625, + "special": false, + "text": " sure" + }, + { + "id": 513, + "logprob": -1.6171875, + "special": false, + "text": " if" + }, + { + "id": 315, + "logprob": -1.3837891, + "special": false, + "text": " I" + }, + { + "id": 541, + "logprob": -1.2226562, + "special": false, + "text": " can" + }, + { + "id": 1567, + "logprob": -1.8652344, + "special": false, + "text": " come" + }, + { + "id": 582, + "logprob": -0.0070228577, + "special": false, + "text": " up" + }, + { + "id": 395, + "logprob": -0.0054092407, + "special": false, + "text": " with" + }, + { + "id": 28705, + "logprob": -0.62597656, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.0035572052, + "special": false, + "text": "3" + }, + { + "id": 4842, + "logprob": -0.93603516, + "special": false, + "text": " unique" + }, + { + "id": 3085, + "logprob": -0.028411865, + "special": false, + "text": " words" + }, + { + "id": 369, + "logprob": -1.0400391, + "special": false, + "text": " that" + }, + { + "id": 6685, + "logprob": -0.09710693, + 
"special": false, + "text": " describe" + }, + { + "id": 528, + "logprob": -0.066467285, + "special": false, + "text": " me" + }, + { + "id": 28725, + "logprob": -1.0722656, + "special": false, + "text": "," + }, + { + "id": 562, + "logprob": -0.33422852, + "special": false, + "text": " but" + }, + { + "id": 315, + "logprob": -0.5136719, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -0.8989258, + "special": false, + "text": "’" + }, + { + "id": 584, + "logprob": -0.2076416, + "special": false, + "text": "ll" + }, + { + "id": 1464, + "logprob": -0.8808594, + "special": false, + "text": " try" + }, + { + "id": 28723, + "logprob": -0.88427734, + "special": false, + "text": "." + }, + { + "id": 13, + "logprob": -0.91064453, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.08105469, + "special": false, + "text": "\n" + }, + { + "id": 28740, + "logprob": -1.8486328, + "special": false, + "text": "1" + }, + { + "id": 28723, + "logprob": -0.111572266, + "special": false, + "text": "." + }, + { + "id": 23626, + "logprob": -3.15625, + "special": false, + "text": " Creative" + }, + { + "id": 13, + "logprob": -0.9194336, + "special": false, + "text": "\n" + }, + { + "id": 28750, + "logprob": -0.24841309, + "special": false, + "text": "2" + }, + { + "id": 28723, + "logprob": -9.393692e-05, + "special": false, + "text": "." + }, + { + "id": 6785, + "logprob": -3.1386719, + "special": false, + "text": " Fun" + }, + { + "id": 1780, + "logprob": -0.53564453, + "special": false, + "text": "ny" + }, + { + "id": 13, + "logprob": -0.09033203, + "special": false, + "text": "\n" + }, + { + "id": 28770, + "logprob": -0.00466156, + "special": false, + "text": "3" + }, + { + "id": 28723, + "logprob": -0.00016450882, + "special": false, + "text": "." + } + ] + }, + "generated_text": "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3." 
+} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json new file mode 100644 index 000000000..91eb5edff --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json @@ -0,0 +1,53 @@ +{ + "details": { + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 1, + "logprob": -0.49658203, + "special": true, + "text": "" + }, + { + "id": 28705, + "logprob": -0.0016384125, + "special": false, + "text": " " + }, + { + "id": 1, + "logprob": -1.4931641, + "special": true, + "text": "" + }, + { + "id": 28705, + "logprob": -0.00075769424, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -0.25024414, + "special": false, + "text": "1" + }, + { + "id": 28740, + "logprob": -0.2631836, + "special": false, + "text": "1" + }, + { + "id": 2, + "logprob": -0.0003285408, + "special": true, + "text": "" + } + ] + }, + "generated_text": " 11" +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json new file mode 100644 index 000000000..130186884 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.0488281, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.0800781, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -2.1152344, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -1.6748047, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -0.097229004, + "special": false, + "text": "1" + }, + { + "id": 28723, + "logprob": -0.16467285, + "special": false, + "text": "." + }, + { + "id": 7615, + "logprob": -2.2246094, + "special": false, + "text": " News" + }, + { + "id": 13, + "logprob": -1.0488281, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.69189453, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.013343811, + "special": false, + "text": " " + }, + { + "id": 28750, + "logprob": -0.011230469, + "special": false, + "text": "2" + }, + { + "id": 28723, + "logprob": -0.00096845627, + "special": false, + "text": "." + }, + { + "id": 21095, + "logprob": -2.5605469, + "special": false, + "text": " Blog" + }, + { + "id": 13, + "logprob": -0.19458008, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.031280518, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.0030708313, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.0029277802, + "special": false, + "text": "3" + }, + { + "id": 28723, + "logprob": -0.0012350082, + "special": false, + "text": "." 
+ }, + { + "id": 20108, + "logprob": -2.1582031, + "special": false, + "text": " Article" + }, + { + "id": 13, + "logprob": -0.05810547, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.35083008, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.034332275, + "special": false, + "text": " " + }, + { + "id": 28781, + "logprob": -0.009666443, + "special": false, + "text": "4" + }, + { + "id": 28723, + "logprob": -0.0013113022, + "special": false, + "text": "." + }, + { + "id": 8349, + "logprob": -2.6191406, + "special": false, + "text": " Review" + }, + { + "id": 13, + "logprob": -0.04031372, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.45239258, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.045410156, + "special": false, + "text": " " + }, + { + "id": 28782, + "logprob": -0.0041236877, + "special": false, + "text": "5" + }, + { + "id": 28723, + "logprob": -0.0010223389, + "special": false, + "text": "." + }, + { + "id": 5299, + "logprob": -2.8066406, + "special": false, + "text": " Other" + }, + { + "id": 13, + "logprob": -0.12054443, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.44580078, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.4921875, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.3574219, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.0039062, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.5859375, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.43481445, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.2783203, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.20410156, + "special": false, + "text": "\n" + } + ] + }, + "generated_text": "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json new file mode 100644 index 000000000..8c00dee75 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.31347656, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.27441406, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -2.2285156, + "special": false, + "text": "I" + }, + { + "id": 28809, + "logprob": -1.4677734, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.31762695, + "special": false, + "text": "m" + }, + { + "id": 264, + "logprob": -1.6865234, + "special": false, + "text": " a" + }, + { + "id": 1215, + "logprob": -3.2695312, + "special": false, + "text": " very" + }, + { + "id": 20640, + "logprob": -3.1230469, + "special": false, + "text": " passionate" + }, + { + "id": 1338, + "logprob": -0.48339844, + "special": false, + "text": " person" + }, + { + "id": 28723, + "logprob": -0.9970703, + "special": false, + "text": "." 
+ }, + { + "id": 315, + "logprob": -0.5498047, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -1.1923828, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.080444336, + "special": false, + "text": "m" + }, + { + "id": 1215, + "logprob": -1.8271484, + "special": false, + "text": " very" + }, + { + "id": 12215, + "logprob": -2.8847656, + "special": false, + "text": " driven" + }, + { + "id": 28723, + "logprob": -1.0927734, + "special": false, + "text": "." + }, + { + "id": 315, + "logprob": -0.4584961, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -0.5019531, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.030715942, + "special": false, + "text": "m" + }, + { + "id": 1215, + "logprob": -0.96972656, + "special": false, + "text": " very" + }, + { + "id": 7798, + "logprob": -2.8847656, + "special": false, + "text": " determined" + }, + { + "id": 28723, + "logprob": -0.27319336, + "special": false, + "text": "." + }, + { + "id": 13, + "logprob": -0.56396484, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.011016846, + "special": false, + "text": "\n" + }, + { + "id": 3195, + "logprob": -0.7163086, + "special": false, + "text": "What" + }, + { + "id": 349, + "logprob": -1.1611328, + "special": false, + "text": " is" + }, + { + "id": 574, + "logprob": -0.515625, + "special": false, + "text": " your" + }, + { + "id": 6656, + "logprob": -1.0253906, + "special": false, + "text": " favorite" + }, + { + "id": 1970, + "logprob": -2.1738281, + "special": false, + "text": " thing" + }, + { + "id": 684, + "logprob": -0.48364258, + "special": false, + "text": " about" + }, + { + "id": 1250, + "logprob": -1.8876953, + "special": false, + "text": " being" + }, + { + "id": 264, + "logprob": -0.41967773, + "special": false, + "text": " a" + }, + { + "id": 8626, + "logprob": -2.9160156, + "special": false, + "text": " teacher" + }, + { + "id": 28804, + "logprob": -0.11920166, + "special": false, + "text": "?" + }, + { + "id": 13, + "logprob": -0.023727417, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.010848999, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -1.0566406, + "special": false, + "text": "I" + }, + { + "id": 2016, + "logprob": -0.7163086, + "special": false, + "text": " love" + }, + { + "id": 272, + "logprob": -1.9169922, + "special": false, + "text": " the" + }, + { + "id": 1639, + "logprob": -2.03125, + "special": false, + "text": " fact" + } + ] + }, + "generated_text": "\n\nI’m a very passionate person. I’m very driven. 
I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json index eaba50783..9079b3bdb 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json @@ -14,55 +14,55 @@ }, { "id": 187, - "logprob": -0.26953125, + "logprob": -0.35742188, "special": false, "text": "\n" }, { "id": 30763, - "logprob": -1.1953125, + "logprob": -1.1015625, "special": false, "text": "Deep" }, { "id": 4715, - "logprob": -0.53515625, + "logprob": -0.5234375, "special": false, "text": " learning" }, { "id": 310, - "logprob": -0.625, + "logprob": -0.55078125, "special": false, "text": " is" }, { "id": 247, - "logprob": -0.6796875, + "logprob": -0.6640625, "special": false, "text": " a" }, { "id": 747, - "logprob": -2.0, + "logprob": -2.0625, "special": false, "text": " new" }, { "id": 1511, - "logprob": -2.3125, + "logprob": -2.375, "special": false, "text": " type" }, { "id": 273, - "logprob": -0.0028533936, + "logprob": -0.0029144287, "special": false, "text": " of" }, { "id": 5145, - "logprob": -1.265625, + "logprob": -1.2734375, "special": false, "text": " machine" } diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json index 85e9a9e04..ef88926ca 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -52,7 +52,7 @@ }, { "id": 9830, - "logprob": -1.65625, + "logprob": -2.25, "special": false, "text": " colors" }, @@ -64,13 +64,13 @@ }, { "id": 329, - "logprob": -2.4375, + "logprob": -2.171875, "special": false, "text": " A" }, { "id": 1180, - "logprob": -1.953125, + "logprob": -2.046875, "special": false, "text": " number" }, diff --git a/integration-tests/models/test_chat_llama.py b/integration-tests/models/test_chat_llama.py index 10df6dbda..1f7a4a596 100644 --- a/integration-tests/models/test_chat_llama.py +++ b/integration-tests/models/test_chat_llama.py @@ -1,7 +1,4 @@ import pytest -import json - -from text_generation.types import GrammarType @pytest.fixture(scope="module") diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index 0efb66938..d787873b0 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream( chunk = [c.replace("data:", "") for c in chunk] # remove empty strings chunk = [c for c in chunk if c] + # remove completion marking chunk + chunk = [c for c in chunk if c != " [DONE]"] # parse json chunk = [json.loads(c) for c in chunk] diff --git a/integration-tests/models/test_flash_deepseek_v2.py b/integration-tests/models/test_flash_deepseek_v2.py new file mode 100644 index 000000000..010e08c90 --- /dev/null +++ b/integration-tests/models/test_flash_deepseek_v2.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_deepseek_v2_handle(launcher): + with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_deepseek_v2(flash_deepseek_v2_handle): + await flash_deepseek_v2_handle.health(300) 
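As an aside to the streaming-completion change earlier in this patch (test_completion_prompts.py now drops the " [DONE]" sentinel before JSON-decoding each SSE chunk), here is a self-contained sketch of that filtering step; the `raw_chunks` payload is invented for illustration.

import json

raw_chunks = ['data:{"choices": [{"text": "Hello"}]}', "data: [DONE]", ""]
chunks = [c.replace("data:", "") for c in raw_chunks]
chunks = [c for c in chunks if c]                # drop empty keep-alive strings
chunks = [c for c in chunks if c != " [DONE]"]   # drop the completion marker
parsed = [json.loads(c) for c in chunks]
assert parsed == [{"choices": [{"text": "Hello"}]}]

Without the extra filter, json.loads would raise on the sentinel chunk, which is presumably why the test now strips it before parsing.
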
+ return flash_deepseek_v2_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_load( + flash_deepseek_v2, generate_load, response_snapshot +): + responses = await generate_load( + flash_deepseek_v2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_gemma2.py b/integration-tests/models/test_flash_gemma2.py new file mode 100644 index 000000000..547db4939 --- /dev/null +++ b/integration-tests/models/test_flash_gemma2.py @@ -0,0 +1,46 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma2_handle(launcher): + with launcher("google/gemma-2-9b-it", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma2(flash_gemma2_handle): + await flash_gemma2_handle.health(300) + return flash_gemma2_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma2(flash_gemma2, response_snapshot): + response = await flash_gemma2.generate( + "user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.\nmodel:\n", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.generated_text == "**Hydrogen**, light and free,\n**He" + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma2_load(flash_gemma2, generate_load, response_snapshot): + responses = await generate_load( + flash_gemma2, + "user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.\nmodel:\n", + max_new_tokens=10, + n=4, + ) + + assert responses[0].generated_text == "**Hydrogen**, light and free,\n**He" + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_llama_fp8.py b/integration-tests/models/test_flash_llama_fp8.py new file mode 100644 index 000000000..fe5df590c --- /dev/null +++ b/integration-tests/models/test_flash_llama_fp8.py @@ -0,0 +1,62 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_fp8_handle(launcher): + with launcher("meta-llama/Meta-Llama-3-8B", num_shard=2, quantize="fp8") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_fp8(flash_llama_fp8_handle): + await flash_llama_fp8_handle.health(300) + return 
flash_llama_fp8_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot): + responses = await generate_load( + flash_llama_fp8, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_llama_marlin_24.py b/integration-tests/models/test_flash_llama_marlin_24.py new file mode 100644 index 000000000..3eb94f02e --- /dev/null +++ b/integration-tests/models/test_flash_llama_marlin_24.py @@ -0,0 +1,66 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_marlin24_handle(launcher): + with launcher( + "nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_marlin(flash_llama_marlin24_handle): + await flash_llama_marlin24_handle.health(300) + return flash_llama_marlin24_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_load( + flash_llama_marlin, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_marlin, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py index 3ead3150b..52ecaed46 100644 --- a/integration-tests/models/test_flash_pali_gemma.py +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -1,6 +1,4 @@ import pytest -import requests -import 
io import base64 diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index b7725f0bb..eb573385b 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -74,9 +74,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot): generated_texts = [r.generated_text for r in responses] - assert ( - generated_texts[0] == " \nAssistant: A rooster stands" - ), f"{response.generated_text}" + assert generated_texts[0] == " \nAssistant: A rooster stands" assert len(generated_texts) == 4 assert generated_texts, all( [text == generated_texts[0] for text in generated_texts] diff --git a/integration-tests/models/test_lora_mistral.py b/integration-tests/models/test_lora_mistral.py new file mode 100644 index 000000000..ccdc14863 --- /dev/null +++ b/integration-tests/models/test_lora_mistral.py @@ -0,0 +1,134 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def lora_mistral_handle(launcher): + with launcher( + "mistralai/Mistral-7B-v0.1", + lora_adapters=[ + "predibase/dbpedia", + "predibase/customer_support", + ], + cuda_graphs=[0], + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def lora_mistral(lora_mistral_handle): + await lora_mistral_handle.health(300) + return lora_mistral_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral(lora_mistral, response_snapshot): + response = await lora_mistral.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + assert response.details.generated_tokens == 10 + + +classification_prompt = """You are given the title and the body of an article below. Please determine the type of the article.\n### Title: Great White Whale\n\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008. A music video was shot in Pittsburgh for the album's first single XOXOXO. The album reached number 17 on iTunes's top 100 albums in its first week on sale.\n\n### Article Type:""" + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_without_adapter(lora_mistral, response_snapshot): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": classification_prompt, + "parameters": { + "max_new_tokens": 40, + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. 
Other\n\n\n\n\n\n\n\n\n" + ) + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_with_dbpedia_adapter(lora_mistral, response_snapshot): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": classification_prompt, + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/dbpedia", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["generated_text"] == " 11" + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_with_customer_support_adapter( + lora_mistral, response_snapshot +): + print(lora_mistral.base_url) + print(lora_mistral.headers) + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": "What are 3 unique words that describe you?", + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/customer_support", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3." + ) + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_without_customer_support_adapter( + lora_mistral, response_snapshot +): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": "What are 3 unique words that describe you?", + "parameters": { + "max_new_tokens": 40, + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\nI’m a very passionate person. I’m very driven. 
I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact" + ) + assert data == response_snapshot diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 0af3f66ac..f831990a3 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -1,7 +1,4 @@ import pytest -import json - -from text_generation.types import GrammarType @pytest.fixture(scope="module") @@ -91,7 +88,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna }, ], ) - assert response.choices[0].message.content == None + assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { "id": 0, @@ -129,7 +126,7 @@ async def test_flash_llama_grammar_tools_auto( }, ], ) - assert response.choices[0].message.content == None + assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { "id": 0, @@ -168,7 +165,7 @@ async def test_flash_llama_grammar_tools_choice( }, ], ) - assert response.choices[0].message.content == None + assert response.choices[0].message.content is None assert response.choices[0].message.tool_calls == [ { "id": 0, @@ -241,7 +238,7 @@ async def test_flash_llama_grammar_tools_insufficient_information( stream=False, ) - assert responses.choices[0].message.content == None + assert responses.choices[0].message.content is None assert responses.choices[0].message.tool_calls == [ { "function": { diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d2ca38e5a..8acfda0cb 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,5 +1,8 @@ use clap::{Parser, ValueEnum}; -use hf_hub::{api::sync::Api, Repo, RepoType}; +use hf_hub::{ + api::sync::{Api, ApiBuilder}, + Repo, RepoType, +}; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; use serde::Deserialize; @@ -25,6 +28,7 @@ mod env_runtime; struct RawConfig { max_position_embeddings: Option, n_positions: Option, + model_type: Option, max_seq_len: Option, } @@ -164,6 +168,33 @@ impl std::fmt::Display for RopeScaling { } } +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum UsageStatsLevel { + /// Default option, usage statistics are collected anonymously + On, + /// Disables all collection of usage statistics + Off, + /// Doesn't send the error stack trace or error type, but allows sending a crash event + NoStack, +} + +impl std::fmt::Display for UsageStatsLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + UsageStatsLevel::On => { + write!(f, "on") + } + UsageStatsLevel::Off => { + write!(f, "off") + } + UsageStatsLevel::NoStack => { + write!(f, "no-stack") + } + } + } +} + /// App Configuration #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] @@ -418,6 +449,10 @@ struct Args { #[clap(long, env)] cors_allow_origin: Vec, + + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] watermark_gamma: Option, #[clap(long, env)] @@ -457,6 +492,12 @@ struct Args { /// startup that will be available to callers via the `adapter_id` field in a request. #[clap(long, env)] lora_adapters: Option, + + /// Control if anonymous usage stats are collected. + /// Options are "on", "off" and "no-stack" + /// Defaul is on. 
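The UsageStatsLevel enum and its Display impl above define the literal values ("on", "off", "no-stack") that the launcher forwards to the router via --usage-stats a little further down in this patch, alongside the optional --api-key. A small Python sketch of the same mapping and argument assembly, purely illustrative (the helper name `router_args` is made up):

from enum import Enum
from typing import List, Optional

class UsageStatsLevel(Enum):
    ON = "on"              # default: anonymous usage statistics are collected
    OFF = "off"            # disable all collection
    NO_STACK = "no-stack"  # allow a crash event, but no stack trace or error type

def router_args(usage_stats: UsageStatsLevel, api_key: Optional[str] = None) -> List[str]:
    # Mirrors the launcher: the usage-stats value is always forwarded,
    # and --api-key is only added when a key was configured.
    args = ["--usage-stats", usage_stats.value]
    if api_key is not None:
        args += ["--api-key", api_key]
    return args

assert router_args(UsageStatsLevel.NO_STACK) == ["--usage-stats", "no-stack"]
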
+ #[clap(default_value = "on", long, env)] + usage_stats: UsageStatsLevel, } #[derive(Debug)] @@ -1201,6 +1242,10 @@ fn spawn_webserver( args.model_id, ]; + // Pass usage stats flags to router + router_args.push("--usage-stats".to_string()); + router_args.push(args.usage_stats.to_string()); + // Grammar support if args.disable_grammar_support { router_args.push("--disable-grammar-support".to_string()); @@ -1251,6 +1296,11 @@ fn spawn_webserver( router_args.push(origin); } + // API Key + if let Some(api_key) = args.api_key { + router_args.push("--api-key".to_string()); + router_args.push(api_key); + } // Ngrok if args.ngrok { router_args.push("--ngrok".to_string()); @@ -1384,7 +1434,13 @@ fn main() -> Result<(), LauncherError> { let mut path = std::path::Path::new(&args.model_id).to_path_buf(); let filename = if !path.exists() { // Assume it's a hub id - let api = Api::new()?; + + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; let repo = if let Some(ref revision) = args.revision { api.repo(Repo::with_revision( model_id, @@ -1402,6 +1458,11 @@ fn main() -> Result<(), LauncherError> { let content = std::fs::read_to_string(filename)?; let config: RawConfig = serde_json::from_str(&content)?; + + if config.model_type == Some("gemma2".to_string()) { + tracing::info!("Forcing flash decoding because of softcap usage"); + std::env::set_var("FLASH_DECODING", "1"); + } let config: Config = config.into(); // Quantization usually means you're even more RAM constrained. @@ -1576,6 +1637,10 @@ fn main() -> Result<(), LauncherError> { // Download and convert lora adapters if any if let Some(lora_adapters) = &args.lora_adapters { for adapter in lora_adapters.split(',') { + // skip download if a path is provided + if adapter.contains('=') { + continue; + } download_convert_model( adapter, None, diff --git a/load_tests/orca.py b/load_tests/orca.py new file mode 100644 index 000000000..e69de29bb diff --git a/router/Cargo.toml b/router/Cargo.toml index 60fb5c9d1..1be745464 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -7,25 +7,18 @@ edition.workspace = true authors.workspace = true homepage.workspace = true -[lib] -path = "src/lib.rs" - -[[bin]] -name = "text-generation-router" -path = "src/main.rs" - [dependencies] +async-trait = "0.1.74" async-stream = "0.3.5" axum = { version = "0.7", features = ["json"] } axum-tracing-opentelemetry = "0.16" -text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } -metrics = "0.23.0" -metrics-exporter-prometheus = { version = "0.15.1", features = [] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.13.0" @@ -52,6 +45,11 @@ regex = "1.10.3" once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } +sysinfo = "0.30.13" +uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +csv = "1.3.0" +ureq = "=2.9" + [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/client/src/v2/pb/.gitignore b/router/client/src/v2/pb/.gitignore deleted file mode 100644 index 72e8ffc0d..000000000 --- 
a/router/client/src/v2/pb/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/router/client/src/v3/pb/.gitignore b/router/client/src/v3/pb/.gitignore deleted file mode 100644 index 72e8ffc0d..000000000 --- a/router/client/src/v3/pb/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/chat_template.rs similarity index 67% rename from router/src/infer/v3/scheduler.rs rename to router/src/infer/chat_template.rs index 26cd95847..24a003528 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/chat_template.rs @@ -1,528 +1,85 @@ -/// Batching and inference logic -use crate::infer::v3::queue::{Entry, Queue}; -use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, +use crate::infer::InferError; +use crate::{ + ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, }; -use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::ClientError; -use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; -use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{info_span, instrument, Instrument, Span}; +use minijinja::{Environment, ErrorKind, Template}; +use minijinja_contrib::pycompat; -pub(crate) struct SchedulerV3 { - /// Request queue - queue: Queue, - /// Notify batcher on queue appends - batching_task_notifier: Arc, +/// Raise a exception (custom function) used in the chat templates +pub(crate) fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) } -impl SchedulerV3 { - #[allow(clippy::too_many_arguments)] +#[derive(Clone)] +pub(crate) struct ChatTemplate { + template: Template<'static, 'static>, + bos_token: Option, + eos_token: Option, + use_default_tool_template: bool, +} + +impl ChatTemplate { pub(crate) fn new( - client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - requires_padding: bool, - window_size: Option, - speculate: u32, - generation_health: Arc, + template: String, + bos_token: Option, + eos_token: Option, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") - } else { - false - }; - let block_size = if flashdecoding { 256 } else { 16 }; - let queue = Queue::new( - requires_padding, - block_size, - window_size, - speculate, - max_batch_total_tokens, - ); - let batching_task_notifier = Arc::new(Notify::new()); + let mut env = Box::new(Environment::new()); + // enable things like .strip() or .capitalize() + env.set_unknown_method_callback(pycompat::unknown_method_callback); + let template_str = template.into_boxed_str(); + env.add_function("raise_exception", raise_exception); - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - queue.clone(), - batching_task_notifier.clone(), - generation_health, - )); + // check if contains the tools variable within the template + 
let use_default_tool_template = + !template_str.as_ref().replace(' ', "").contains("{{tools}}"); + // leaking env and template_str as read-only, static resources for performance. + let template = Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap(); Self { - queue, - batching_task_notifier, + template, + bos_token: bos_token.map(|token| token.as_str().to_string()), + eos_token: eos_token.map(|token| token.as_str().to_string()), + use_default_tool_template, } } -} -impl Scheduler for SchedulerV3 { - #[instrument(skip_all)] - fn schedule( + pub(crate) fn apply( &self, - request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result { - // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = request.input_length; - - // Append the request to the queue - self.queue.append(Entry { - request, - response_tx, - span: Span::current(), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - block_allocation: None, - }); - - // Notify the background task that we have a new entry in the queue that needs - // to be batched - self.batching_task_notifier.notify_one(); - - // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) - } -} - -/// Batching logic -/// Will be launched in a background Tokio task -/// -/// Batches requests and sends them to the inference server -#[allow(clippy::too_many_arguments)] -pub(crate) async fn batching_task( - mut client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - queue: Queue, - notifier: Arc, - generation_health: Arc, -) { - // Infinite loop - loop { - // Wait for a notification from the Infer struct - notifier.notified().await; - - // Get the next batch from the queue - // This batch might be smaller than the maximum batch size if there are not enough requests - // waiting in the queue - while let Some((mut entries, batch, span)) = queue - .next_batch( - None, - max_batch_size, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) - .await - { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) - .instrument(span) - .await; - let mut waiting_tokens = 1; - - // We loop until we do not receive any cached batch from the inference server (== until - // all requests have met their stopping criteria) - while let Some(batch) = cached_batch { - // Get current batch info - let batch_size = batch.size; - let batch_max_tokens = batch.max_tokens; - let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); - - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None - } else { - // Minimum batch size - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - }; - - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); - - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) - .await - { - // Tracking metrics - if min_size.is_some() { - 
metrics::counter!("tgi_batch_concat", "reason" => "backpressure") - .increment(1); - } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") - .increment(1); - } - - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content.push(MessageChunk::Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), }); - - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } } - - // Create span for this batch to add context to inference calls - let next_batch_size = entries.len(); - let next_batch_span = - info_span!(parent: None, "batch", batch_size = next_batch_size); - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - }); - - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) - .instrument(next_batch_span) - .await; - waiting_tokens += 1; - } - metrics::gauge!("tgi_batch_current_size").set(0.0); - metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); - } - } -} - -#[instrument(skip_all)] -async fn prefill( - client: &mut ShardedClient, - batch: Batch, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_id = batch.id; - metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); - - match client.prefill(batch).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); - next_batch - } - // 
If we have an error, we discard the whole batch - Err(err) => { - // Update health - generation_health.store(false, Ordering::SeqCst); - let _ = client.clear_cache(Some(batch_id)).await; - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); - None - } - } -} - -#[instrument(skip_all)] -async fn decode( - client: &mut ShardedClient, - batches: Vec, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); - - match client.decode(batches).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") - .record(concat_duration.as_secs_f64()); - } - metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - generation_health.store(false, Ordering::SeqCst); - for id in batch_ids { - let _ = client.clear_cache(Some(id)).await; - } - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); - None - } - } -} - -/// Filter a `batch` and remove all requests not present in `entries` -#[instrument(skip_all)] -async fn filter_batch( - client: &mut ShardedClient, - next_batch: Option, - entries: &IntMap, -) -> Option { - let mut batch = next_batch?; - - // No need to filter - if batch.size as usize == entries.len() { - return Some(batch); - } - - let id = batch.id; - - // Retain only requests that are still in entries - batch.request_ids.retain(|id| entries.contains_key(id)); - - if batch.request_ids.is_empty() { - // All requests have been filtered out - // Next batch is now empty - // Clear it from the Python shards cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.clear_cache(Some(id)).await.unwrap(); - None - } else { - // Filter Python shard cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, batch.request_ids).await.unwrap() - } -} - -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -/// and filter entries -#[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { - let id = generation.request_id; - // Get entry - // We can `expect` here as the request id should always be in the entries - let entry = entries - .get(&id) - 
.expect("ID not found in entries. This is a bug."); - - // Create and enter a span to link this function back to the entry - let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); - // Send generation responses back to the infer task - // If the receive an error from the Flume channel, it means that the client dropped the - // request and we need to stop generating hence why we unwrap_or(true) - let stopped = send_responses(generation, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - err - }).unwrap_or(true); - if stopped { - entries.remove(&id).expect("ID not found in entries. This is a bug."); - } - }); -} - -/// Send responses through the `entry` response channel -fn send_responses( - generation: Generation, - entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - return Ok(true); - } - - let mut stopped = false; - - if let Some(prefill_tokens) = generation.prefill_tokens { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - let prefill_tokens = prefill_tokens - .ids - .into_iter() - .zip(prefill_tokens.logprobs) - .zip(prefill_tokens.texts) - .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) - .collect(); - - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; - } - - // Create last Token - let tokens_ = generation.tokens.expect("Non empty tokens in generation"); - let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); - let mut iterator = tokens_ - .ids - .into_iter() - .zip(tokens_.logprobs) - .zip(tokens_.texts) - .zip(tokens_.is_special) - .enumerate() - .peekable(); - while let Some((i, (((id, logprob), text), special))) = iterator.next() { - let token = Token { - id, - text, - logprob, - special, - }; - let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { - top_tokens_ - .ids - .iter() - .zip(top_tokens_.logprobs.iter()) - .zip(top_tokens_.texts.iter()) - .zip(top_tokens_.is_special.iter()) - .map(|(((&id, &logprob), text), &special)| Token { - id, - text: text.to_string(), - logprob, - special, - }) - .collect() - } else { - vec![] - }; - match (&generation.generated_text, iterator.peek()) { - (Some(generated_text), None) => { - // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - token, - top_tokens, - generated_text: GeneratedText::from(generated_text.clone()), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; - } - _ => { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; } } - } - Ok(stopped) -} + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); -/// Send errors to Infer for all `entries` -#[instrument(skip_all)] -fn send_errors(error: ClientError, entries: &mut IntMap) { - entries.drain().for_each(|(_, entry)| { - // Create and enter a span to link this function back to the entry - let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. 
This is a bug."), "send_error").entered(); - let err = InferError::GenerationError(error.to_string()); - metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); - tracing::error!("{err}"); - - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Err(err)) - .unwrap_or(()); - }); -} - -impl From for GeneratedText { - fn from(value: text_generation_client::v3::GeneratedText) -> Self { - let v3_finish_reason = - text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); - let finish_reason = match v3_finish_reason { - text_generation_client::v3::FinishReason::Length => FinishReason::Length, - text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, - }; - - Self { - text: value.text, - generated_tokens: value.generated_tokens, - finish_reason, - seed: value.seed, - } + self.template + .render(ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools: None, + tools_prompt: None, + }) + .map_err(InferError::TemplateError) } } // tests #[cfg(test)] mod tests { - use crate::infer::raise_exception; + use crate::infer::chat_template::raise_exception; use crate::{ChatTemplateInputs, TextMessage}; use minijinja::Environment; diff --git a/router/src/infer/health.rs b/router/src/infer/health.rs deleted file mode 100644 index 4320c1a4d..000000000 --- a/router/src/infer/health.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use text_generation_client::Health; - -#[derive(Clone)] -pub(crate) struct HealthCheck { - client: Arc, - generation_health: Arc, -} - -impl HealthCheck { - pub(crate) fn new( - client: Arc, - generation_health: Arc, - ) -> Self { - Self { - client, - generation_health, - } - } - - pub(crate) async fn check(&mut self) -> bool { - let value = if self.generation_health.load(Ordering::SeqCst) { - // Generation is healthy, we only check that the shards can allocate on device - self.client.device_health().await - } else { - self.client.model_health().await - } - .is_ok(); - // Update generation health - self.generation_health.store(value, Ordering::SeqCst); - value - } -} diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index f3b10450a..534a2647c 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -1,23 +1,18 @@ -mod health; -pub(crate) mod v2; -pub(crate) mod v3; - -pub(crate) use health::HealthCheck; +// pub(crate) mod v2; +mod chat_template; +pub mod tool_grammar; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; +use crate::GrammarType; use crate::{ - ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, -}; -use crate::{ - FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, + ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, + Message, PrefillToken, Token, }; +use async_trait::async_trait; +use chat_template::ChatTemplate; use futures::future::try_join_all; -use minijinja::{Environment, ErrorKind, Template}; -use minijinja_contrib::pycompat; - -use serde_json::{json, Map, Value}; -use std::collections::HashMap; +use minijinja::ErrorKind; +use 
std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; @@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; -pub(crate) trait Scheduler { +#[async_trait] +pub trait Backend { fn schedule( &self, request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result; + ) -> Result>, InferError>; + + async fn health(&self, current_health: bool) -> bool; } /// Inference struct @@ -39,18 +36,20 @@ pub(crate) trait Scheduler { pub struct Infer { /// Validation validation: Validation, - /// Request scheduler - scheduler: Arc, + /// Request backend + backend: Arc, /// Chat template chat_template: Option, /// Inference limit limit_concurrent_requests: Arc, + /// Backend health + backend_health: Arc, } impl Infer { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - scheduler: Arc, + backend: impl Backend + Send + Sync + 'static, validation: Validation, max_concurrent_requests: usize, tokenizer_config: HubTokenizerConfig, @@ -71,18 +70,22 @@ impl Infer { // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + // Backend health + let backend_health = Arc::new(AtomicBool::new(false)); + Self { validation, - scheduler, + backend: Arc::new(backend), chat_template, limit_concurrent_requests: semaphore, + backend_health, } } /// Add a new request to the queue and return a stream of InferStreamResponse #[instrument(skip_all)] - pub(crate) async fn generate_stream( - &self, + pub(crate) async fn generate_stream<'a>( + &'a self, request: GenerateRequest, ) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore @@ -103,7 +106,10 @@ impl Infer { err })?; - self.scheduler.schedule(valid_request, permit) + let input_length = valid_request.input_length; + let generation_stream = self.backend.schedule(valid_request)?; + + Ok((permit, input_length, generation_stream)) } /// Tokenizer the input @@ -155,7 +161,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + let (_permit, _input_length, stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -165,6 +171,8 @@ impl Infer { let mut result_start = None; let mut result_queued = None; + let mut stream = Box::pin(stream); + // Iterate on stream while let Some(response) = stream.next().await { match response? 
{ @@ -256,202 +264,15 @@ impl Infer { let best_response = infer_responses.remove(max_index); Ok((best_response, infer_responses)) } -} -/// Raise a exception (custom function) used in the chat templates -fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) -} - -#[derive(Clone)] -struct ChatTemplate { - template: Template<'static, 'static>, - bos_token: Option, - eos_token: Option, - use_default_tool_template: bool, -} - -impl ChatTemplate { - fn new( - template: String, - bos_token: Option, - eos_token: Option, - ) -> Self { - let mut env = Box::new(Environment::new()); - // enable things like .strip() or .capitalize() - env.set_unknown_method_callback(pycompat::unknown_method_callback); - let template_str = template.into_boxed_str(); - env.add_function("raise_exception", raise_exception); - - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); - // leaking env and template_str as read-only, static resources for performance. - let template = Box::leak(env) - .template_from_str(Box::leak(template_str)) - .unwrap(); - - Self { - template, - bos_token: bos_token.map(|token| token.as_str().to_string()), - eos_token: eos_token.map(|token| token.as_str().to_string()), - use_default_tool_template, - } - } - - fn apply( - &self, - mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - }); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools: None, - tools_prompt: None, - }) - .map_err(InferError::TemplateError) - } -} - -pub struct ToolGrammar {} - -impl ToolGrammar { - pub fn apply( - tools: Option>, - tool_choice: Option, - ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::Function { function } => { - let tool = req_tools - .iter() - .find(|tool| tool.function.name == function.name) - .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) - .clone(); - vec![tool] - } - ToolType::OneOf => req_tools.to_owned(), - }; - - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { 
- params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) + #[instrument(skip(self))] + pub(crate) async fn health(&self) -> bool { + let health = self + .backend + .health(self.backend_health.load(Ordering::SeqCst)) + .await; + self.backend_health.store(health, Ordering::SeqCst); + health } } @@ -463,15 +284,15 @@ pub(crate) type GenerateStreamResponse = ( ); #[derive(Debug)] -pub(crate) struct GeneratedText { - pub(crate) text: String, - pub(crate) generated_tokens: u32, - pub(crate) finish_reason: FinishReason, - pub(crate) seed: Option, +pub struct GeneratedText { + pub text: String, + pub generated_tokens: u32, + pub finish_reason: FinishReason, + pub seed: Option, } #[derive(Debug)] -pub(crate) enum InferStreamResponse { +pub enum InferStreamResponse { // Optional first message Prefill(Vec), // Intermediate messages diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs new file mode 100644 index 000000000..05027f304 --- /dev/null +++ b/router/src/infer/tool_grammar.rs @@ -0,0 +1,135 @@ +use crate::infer::InferError; +use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; + +pub(crate) struct ToolGrammar {} + +impl ToolGrammar { + // find a tool by name + fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { + tools + .iter() + .find(|tool| tool.function.name == name) + .cloned() + .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) + } + + pub fn apply( + tools: Option>, + tool_choice: ToolChoice, + ) -> Result, InferError> { + // if no tools are provided, we return None + let tools = match tools { + Some(tools) if !tools.is_empty() => tools, + _ => return Ok(None), + }; + + let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + + // if tools are provided and no tool_choice we default to the OneOf + let 
tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![Self::find_tool_by_name(&tools, &name)?] + } + ToolType::Function { function } => { + vec![Self::find_tool_by_name(&tools, &function.name)?] + } + ToolType::OneOf => tools, + ToolType::NoTool => return Ok(None), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + Ok(Some(tools)) + } +} diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs index 8b4f6bab3..6a91a433f 100644 --- a/router/src/infer/v2/mod.rs +++ b/router/src/infer/v2/mod.rs @@ -1,4 +1,4 @@ mod queue; mod scheduler; -pub(crate) use scheduler::SchedulerV2; +pub(crate) use scheduler::BackendV2; diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index 97379bc53..3d6c36cf4 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,7 +1,7 @@ /// Batching and inference logic use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, + Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; use crate::{FinishReason, PrefillToken, Token}; @@ -18,14 +18,14 @@ use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use 
tracing::{info_span, instrument, Instrument, Span}; -pub(crate) struct SchedulerV2 { +pub(crate) struct BackendV2 { /// Request queue queue: Queue, /// Notify batcher on queue appends batching_task_notifier: Arc, } -impl SchedulerV2 { +impl BackendV2 { #[allow(clippy::too_many_arguments)] pub(crate) fn new( client: ShardedClient, @@ -69,7 +69,7 @@ impl SchedulerV2 { } } -impl Scheduler for SchedulerV2 { +impl Backend for BackendV2 { #[instrument(skip_all)] fn schedule( &self, diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs deleted file mode 100644 index f9effab8e..000000000 --- a/router/src/infer/v3/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod block_allocator; -mod queue; -mod scheduler; - -pub(crate) use scheduler::SchedulerV3; diff --git a/router/src/lib.rs b/router/src/lib.rs index f856406d6..14bb8270d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,11 +1,14 @@ /// Text Generation Inference Webserver pub mod config; -mod infer; +pub mod infer; pub mod server; -mod validation; +pub mod validation; #[cfg(feature = "kserve")] mod kserve; +pub mod logging; + +pub mod usage_stats; use serde::{Deserialize, Serialize}; use tracing::warn; @@ -40,13 +43,13 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatTemplate { name: String, template: String, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(untagged)] pub enum ChatTemplateVersions { Single(String), @@ -55,7 +58,7 @@ pub enum ChatTemplateVersions { use std::path::Path; -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, @@ -146,12 +149,13 @@ pub struct Info { pub model_id: String, #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] pub model_sha: Option, - #[schema(example = "torch.float16")] - pub model_dtype: String, - #[schema(example = "cuda")] - pub model_device_type: String, + // #[schema(example = "torch.float16")] + // pub model_dtype: String, + // #[schema(example = "cuda")] + // pub model_device_type: String, #[schema(nullable = true, example = "text-generation")] pub model_pipeline_tag: Option, + /// Router Parameters #[schema(example = "128")] pub max_concurrent_requests: usize, @@ -163,18 +167,11 @@ pub struct Info { pub max_input_tokens: usize, #[schema(example = "2048")] pub max_total_tokens: usize, - #[schema(example = "1.2")] - pub waiting_served_ratio: f32, - #[schema(example = "32000")] - pub max_batch_total_tokens: u32, - #[schema(example = "20")] - pub max_waiting_tokens: usize, - #[schema(nullable = true, example = "null")] - pub max_batch_size: Option, #[schema(example = "2")] pub validation_workers: usize, #[schema(example = "32")] pub max_client_batch_size: usize, + /// Router Info #[schema(example = "text-generation-router")] pub router: &'static str, @@ -824,7 +821,7 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - pub tool_choice: Option, + pub tool_choice: ToolChoice, /// Response format constraints for the generation. 
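// A minimal sketch (outside the patch itself) of the `Scheduler` -> `Backend`
// refactor that the renames above follow: a backend now owns both scheduling
// and health checking behind one trait, and hands `Infer` a stream of
// responses. `DemoBackend`, `EchoBackend` and the `String` token type are
// simplified stand-ins, not the real `Backend`/`InferStreamResponse`
// signatures.
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

#[async_trait]
trait DemoBackend {
    fn schedule(&self, prompt: String) -> UnboundedReceiverStream<String>;
    async fn health(&self, current_health: bool) -> bool;
}

struct EchoBackend;

#[async_trait]
impl DemoBackend for EchoBackend {
    fn schedule(&self, prompt: String) -> UnboundedReceiverStream<String> {
        // Like the real backends, push generated "tokens" into an unbounded
        // channel and return the receiving end as a stream.
        let (tx, rx) = mpsc::unbounded_channel();
        for token in prompt.split_whitespace() {
            let _ = tx.send(token.to_string());
        }
        UnboundedReceiverStream::new(rx)
    }

    async fn health(&self, _current_health: bool) -> bool {
        // A real backend would query the shards here.
        true
    }
}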
/// @@ -846,6 +843,7 @@ pub enum ToolType { OneOf, FunctionName(String), Function { function: FunctionName }, + NoTool, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] @@ -853,27 +851,26 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] #[serde(from = "ToolTypeDeserializer")] pub struct ToolChoice(pub Option); #[derive(Deserialize)] #[serde(untagged)] enum ToolTypeDeserializer { - None(Option), - Some(ToolType), + String(String), + ToolType(ToolType), } impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::None(opt) => match opt.as_deref() { - Some("none") => ToolChoice(None), - Some("auto") => ToolChoice(Some(ToolType::OneOf)), - Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), - None => ToolChoice(Some(ToolType::OneOf)), + ToolTypeDeserializer::String(s) => match s.as_str() { + "none" => ToolChoice(Some(ToolType::NoTool)), + "auto" => ToolChoice(Some(ToolType::OneOf)), + _ => ToolChoice(Some(ToolType::FunctionName(s))), }, - ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), + ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } } @@ -1066,23 +1063,23 @@ impl From for GenerateRequest { #[derive(Debug, Serialize, ToSchema)] pub struct PrefillToken { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, } #[derive(Debug, Serialize, ToSchema, Clone)] pub struct Token { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, #[schema(example = "false")] - special: bool, + pub special: bool, } #[derive(Debug, Serialize, ToSchema)] @@ -1100,7 +1097,7 @@ pub struct SimpleToken { #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all(serialize = "snake_case"))] #[schema(example = "Length")] -pub(crate) enum FinishReason { +pub enum FinishReason { #[schema(rename = "length")] Length, #[serde(rename = "eos_token")] diff --git a/router/src/logging.rs b/router/src/logging.rs new file mode 100644 index 000000000..5a98ef57b --- /dev/null +++ b/router/src/logging.rs @@ -0,0 +1,81 @@ +use opentelemetry::sdk::propagation::TraceContextPropagator; +use opentelemetry::sdk::trace; +use opentelemetry::sdk::trace::Sampler; +use opentelemetry::sdk::Resource; +use opentelemetry::{global, KeyValue}; +use opentelemetry_otlp::WithExportConfig; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; + +/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: +/// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - otlp_service_name service name to appear in APM +/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) +/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) +/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) +pub fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { + let mut layers = Vec::new(); + + // STDOUT/STDERR layer + let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); + let fmt_layer = 
tracing_subscriber::fmt::layer() + .with_file(true) + .with_ansi(ansi) + .with_line_number(true); + + let fmt_layer = match json_output { + true => fmt_layer.json().flatten_event(true).boxed(), + false => fmt_layer.boxed(), + }; + layers.push(fmt_layer); + + // OpenTelemetry tracing layer + if let Some(otlp_endpoint) = otlp_endpoint { + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(otlp_endpoint), + ) + .with_trace_config( + trace::config() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + otlp_service_name, + )])) + .with_sampler(Sampler::AlwaysOn), + ) + .install_batch(opentelemetry::runtime::Tokio); + + if let Ok(tracer) = tracer { + layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); + init_tracing_opentelemetry::init_propagator().unwrap(); + }; + } + + // Filter events with LOG_LEVEL + let varname = "LOG_LEVEL"; + let env_filter = if let Ok(log_level) = std::env::var(varname) { + // Override to avoid simple logs to be spammed with tokio level informations + let log_level = match &log_level[..] { + "warn" => "text_generation_launcher=warn,text_generation_router=warn", + "info" => "text_generation_launcher=info,text_generation_router=info", + "debug" => "text_generation_launcher=debug,text_generation_router=debug", + log_level => log_level, + }; + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .parse_lossy(log_level) + } else { + EnvFilter::new("info") + }; + + tracing_subscriber::registry() + .with(env_filter) + .with(layers) + .init(); +} diff --git a/router/src/main.rs b/router/src/main.rs.back similarity index 88% rename from router/src/main.rs rename to router/src/main.rs.back index 21cd66496..36879aa47 100644 --- a/router/src/main.rs +++ b/router/src/main.rs.back @@ -14,6 +14,7 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_router::config::Config; +use text_generation_router::usage_stats; use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; @@ -76,6 +77,8 @@ struct Args { #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] + api_key: Option, + #[clap(long, env)] ngrok: bool, #[clap(long, env)] ngrok_authtoken: Option, @@ -87,6 +90,10 @@ struct Args { disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, + #[clap(long, env, default_value_t)] + disable_usage_stats: bool, + #[clap(long, env, default_value_t)] + disable_crash_reports: bool, } #[derive(Debug, Subcommand)] @@ -122,12 +129,15 @@ async fn main() -> Result<(), RouterError> { otlp_endpoint, otlp_service_name, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, messages_api_enabled, disable_grammar_support, max_client_batch_size, + disable_usage_stats, + disable_crash_reports, command, } = args; @@ -210,7 +220,11 @@ async fn main() -> Result<(), RouterError> { } let api = if use_api { if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { - let cache = Cache::default(); + let cache = std::env::var("HUGGINGFACE_HUB_CACHE") + .map_err(|_| ()) + .map(|cache_dir| Cache::new(cache_dir.into())) + .unwrap_or_else(|_| Cache::default()); + tracing::warn!("Offline mode active using cache defaults"); Type::Cache(cache) } else { @@ -320,6 +334,7 @@ async fn main() -> Result<(), 
RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); + let tokenizer_class = tokenizer_config.tokenizer_class.clone(); let tokenizer: Option = tokenizer_filename.and_then(|filename| { let mut tokenizer = Tokenizer::from_file(filename).ok(); @@ -374,8 +389,47 @@ async fn main() -> Result<(), RouterError> { } }; + // Only send usage stats when TGI is run in container and the function returns Some + let is_container = matches!(usage_stats::is_container(), Ok(true)); + + let user_agent = if !disable_usage_stats && is_container { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_class, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } else { + None + }; + + if let Some(ref ua) = user_agent { + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); + tokio::spawn(async move { + start_event.send().await; + }); + }; + // Run server - server::run( + let result = server::run( master_shard_uds_path, model_info, compat_return_full_text, @@ -395,6 +449,7 @@ async fn main() -> Result<(), RouterError> { validation_workers, addr, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, @@ -406,8 +461,41 @@ async fn main() -> Result<(), RouterError> { max_client_batch_size, print_schema_command, ) - .await?; - Ok(()) + .await; + + match result { + Ok(_) => { + if let Some(ref ua) = user_agent { + let stop_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Stop, + None, + ); + stop_event.send().await; + }; + Ok(()) + } + Err(e) => { + if let Some(ref ua) = user_agent { + if !disable_crash_reports { + let error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some(e.to_string()), + ); + error_event.send().await; + } else { + let unknow_error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some("unknow_error".to_string()), + ); + unknow_error_event.send().await; + } + }; + Err(RouterError::WebServer(e)) + } + } } /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: diff --git a/router/src/server.rs b/router/src/server.rs index 4e5af99c5..dcbaa2ada 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,9 +1,7 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::v2::SchedulerV2; -use crate::infer::v3::SchedulerV3; -use crate::infer::{HealthCheck, Scheduler}; -use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; +use crate::infer::tool_grammar::ToolGrammar; +use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; #[cfg(feature = "kserve")] use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, @@ -11,11 +9,11 @@ use crate::kserve::{ }; use crate::validation::ValidationError; use crate::{ - BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, - GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, - HubTokenizerConfig, Info, 
Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, - SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse, - ToolCallDelta, ToolCallMessage, Url, Usage, Validation, + usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, + GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, + HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, + OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, + TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -24,10 +22,10 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; -use axum::http::{HeaderMap, Method, StatusCode}; +use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; @@ -37,14 +35,18 @@ use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; +use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; +use hf_hub::{Cache, Repo, RepoType}; +use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; -use text_generation_client::{v2, v3, ClientError, ShardInfo}; +use std::fs::File; +use std::io::BufReader; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::{Path, PathBuf}; use thiserror::Error; +use tokenizers::processors::template::TemplateProcessing; use tokenizers::Tokenizer; use tokio::select; use tokio::signal; @@ -123,12 +125,10 @@ responses( example = json ! 
({"error": "unhealthy", "error_type": "healthcheck"})), ) )] -#[instrument(skip(health))] +#[instrument(skip(infer))] /// Health check method -async fn health( - mut health: Extension, -) -> Result<(), (StatusCode, Json)> { - match health.check().await { +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { + match infer.health().await { true => Ok(()), false => Err(( StatusCode::SERVICE_UNAVAILABLE, @@ -429,8 +429,9 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, mut response_stream)) => { + Ok((_permit, _input_length, response_stream)) => { let mut index = 0; + let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream while let Some(response) = response_stream.next().await { index += 1; @@ -812,6 +813,10 @@ async fn completions( } }; + let stream = stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1171,6 +1176,11 @@ async fn chat_completions( span, ) .await; + + let response_stream = response_stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1183,39 +1193,33 @@ async fn chat_completions( .as_secs(); let (tool_calls, output) = if tool_grammar.is_some() { - // gen_text should be valid json - let gen_text_value: Value = - serde_json::from_str(&generation.generated_text).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - })?; + let gen_text_value: Value = serde_json::from_str(&generation.generated_text) + .map_err(|e| InferError::ToolError(e.to_string()))?; + + let function = gen_text_value.get("function").ok_or(InferError::ToolError( + "No function found in generated text".to_string(), + ))?; + + let name = function + .get("_name") + .and_then(Value::as_str) + .ok_or(InferError::ToolError( + "No _name found in generated text".to_string(), + ))? 
+ .to_string(); + + let mut arguments = function.clone(); + if let Value::Object(ref mut props) = arguments { + props.remove("_name"); + } + let tool_calls = vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { description: None, - name: gen_text_value - .get("function") - .and_then(|f| f.get("_name")) - .and_then(|name| name.as_str()) - .unwrap_or("default_function_name") - .to_string(), - // Serialize the JSON object obtained from "function" to an escaped JSON string - arguments: gen_text_value - .get("function") - .map(|f| { - let mut f_cloned = f.clone(); - if let Value::Object(ref mut props) = f_cloned { - props.remove("_name"); - } - f_cloned - }) - .unwrap_or_default(), + name, + arguments, }, }]; (Some(tool_calls), None) @@ -1392,260 +1396,451 @@ async fn metrics(prom_handle: Extension) -> String { #[derive(Clone, Debug)] pub(crate) struct ComputeType(String); +// OpenAPI documentation +#[derive(OpenApi)] +#[openapi( +paths( +health, +get_model_info, +compat_generate, +generate, +generate_stream, +chat_completions, +completions, +tokenize, +metrics, +), +components( +schemas( +Info, +CompatGenerateRequest, +GenerateRequest, +GrammarType, +ChatRequest, +Message, +MessageContent, +MessageChunk, +Url, +FunctionName, +OutputMessage, +TextMessage, +ToolCallMessage, +ToolCallDelta, +ChatCompletionComplete, +ChatCompletionChoice, +ChatCompletionDelta, +ChatCompletionChunk, +ChatCompletionLogprob, +ChatCompletionLogprobs, +ChatCompletionTopLogprob, +ChatCompletion, +CompletionRequest, +CompletionComplete, +Chunk, +Completion, +CompletionFinal, +Prompt, +GenerateParameters, +PrefillToken, +Token, +GenerateResponse, +TokenizeResponse, +SimpleToken, +BestOfSequence, +Details, +FinishReason, +StreamResponse, +StreamDetails, +ErrorResponse, +GrammarType, +Usage, +DeltaToolCall, +ToolType, +Tool, +ToolCall, +Function, +FunctionDefinition, +ToolChoice, +) +), +tags( +(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") +), +info( +title = "Text Generation Inference", +license( +name = "Apache 2.0", +url = "https://www.apache.org/licenses/LICENSE-2.0" +) +) +)] +pub struct ApiDoc; + +pub fn schema() -> ApiDoc { + ApiDoc +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( - master_shard_uds_path: String, - model_info: HubModelInfo, - compat_return_full_text: bool, + backend: impl Backend + Send + Sync + 'static, max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, - tokenizer: Option, - config: Option, validation_workers: usize, - addr: SocketAddr, - allow_origin: Option, + api_key: Option, + tokenizer_name: String, + tokenizer_config_path: Option, + revision: Option, + hostname: String, + port: u16, + cors_allow_origin: Option>, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, - tokenizer_config: HubTokenizerConfig, - preprocessor_config: Option, - processor_config: HubProcessorConfig, messages_api_enabled: bool, - grammar_support: bool, + disable_grammar_support: bool, max_client_batch_size: usize, - print_schema_command: bool, + usage_stats_level: usage_stats::UsageStatsLevel, ) -> Result<(), WebServerError> { - // OpenAPI documentation - #[derive(OpenApi)] - #[openapi( - paths( - health, - get_model_info, - 
compat_generate, - generate, - generate_stream, - chat_completions, - completions, - tokenize, - metrics, - ), - components( - schemas( - Info, - CompatGenerateRequest, - GenerateRequest, - GrammarType, - ChatRequest, - Message, - MessageContent, - MessageChunk, - Url, - FunctionName, - OutputMessage, - TextMessage, - ToolCallMessage, - ToolCallDelta, - ChatCompletionComplete, - ChatCompletionChoice, - ChatCompletionDelta, - ChatCompletionChunk, - ChatCompletionLogprob, - ChatCompletionLogprobs, - ChatCompletionTopLogprob, - ChatCompletion, - CompletionRequest, - CompletionComplete, - Chunk, - Completion, - CompletionFinal, - Prompt, - GenerateParameters, - PrefillToken, - Token, - GenerateResponse, - TokenizeResponse, - SimpleToken, - BestOfSequence, - Details, - FinishReason, - StreamResponse, - StreamDetails, - ErrorResponse, - GrammarType, - Usage, - DeltaToolCall, - ToolType, - Tool, - ToolCall, - Function, - FunctionDefinition, - ) - ), - tags( - (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") - ), - info( - title = "Text Generation Inference", - license( - name = "Apache 2.0", - url = "https://www.apache.org/licenses/LICENSE-2.0" - ) - ) - )] - struct ApiDoc; + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { + AllowOrigin::list( + cors_allow_origin + .iter() + .map(|origin| origin.parse::().unwrap()), + ) + }); - // Create state - if print_schema_command { - let api_doc = ApiDoc::openapi(); - let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); - println!("{}", api_doc); - std::process::exit(0); + // Parse Huggingface hub token + let authorization_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); + + // Tokenizer instance + // This will only be used to validate payloads + let local_path = Path::new(&tokenizer_name); + + // Shared API builder initialization + let api_builder = || { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { + builder = builder.with_cache_dir(cache_dir.into()); + } + + builder + }; + + // Decide if we need to use the API based on the revision and local path + let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); + + // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, } - - // Open connection, get model info and warmup - let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( - Arc, - HealthCheck, - ShardInfo, - u32, - ) = { - // Helper function to check both v2 and v3 - let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { - match max_supported_batch_total_tokens { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( - 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), - ); - tracing::warn!("Model does not support automatic max batch total tokens"); - Ok(max_batch_total_tokens) + let api = if use_api { + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let cache = std::env::var("HUGGINGFACE_HUB_CACHE") + .map_err(|_| ()) + .map(|cache_dir| Cache::new(cache_dir.into())) + .unwrap_or_else(|_| Cache::default()); + tracing::warn!("Offline mode active 
using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." - ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(WebServerError::NotEnoughMemory(max_total_tokens)); - } - - Ok(max_supported_batch_total_tokens) - } - } - }; - - let generation_health = Arc::new(AtomicBool::new(false)); - - match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await { - Ok(mut sharded_client) => { - // server is running on v3 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV3::new( - sharded_client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V3"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) - } - Err(_) => { - let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(WebServerError::Connection)?; - - // server is running on v2 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV2::new( - sharded_client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V2"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) } } + } else { + Type::None }; - tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + // Load tokenizer and model info + let ( + tokenizer_filename, + config_filename, + 
tokenizer_config_filename, + preprocessor_config_filename, + processor_config_filename, + model_info, + ) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("preprocessor_config.json")), + Some(local_path.join("processor_config.json")), + None, + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); + + let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + preprocessor_config_filename, + processor_config_filename, + model_info, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + repo.get("preprocessor_config.json"), + repo.get("processor_config.json"), + None, + ) + } + }; + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
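+    // Prefer an explicitly provided `tokenizer_config_path` over the
+    // tokenizer_config.json resolved from the hub or local cache above; if
+    // neither can be parsed, fall back to a default config with a warning.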
+ let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); + + let tokenizer: Option = tokenizer_filename.and_then(|filename| { + let mut tokenizer = Tokenizer::from_file(filename).ok(); + if let Some(tokenizer) = &mut tokenizer { + if let Some(class) = &tokenizer_config.tokenizer_class { + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ + if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { + tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); + tokenizer.with_post_processor(post_processor); + } + } + } + } + tokenizer + }); + + let config: Option = config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result = serde_json::from_str(c); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } + config.ok() + }) + }); + let model_info = model_info.unwrap_or_else(|| HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + }); + + let processor_config = processor_config_filename + .and_then(HubProcessorConfig::from_file) + .unwrap_or_default(); + + let preprocessor_config: Option = + preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); + + tracing::info!("Using config {config:?}"); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length validation and truncation is disabled"); + } + + // Only send usage stats when TGI is run in container and the function returns Some + let is_container = matches!(usage_stats::is_container(), Ok(true)); + let user_agent = match (usage_stats_level, is_container) { + (usage_stats::UsageStatsLevel::On | usage_stats::UsageStatsLevel::NoStack, true) => { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_config.tokenizer_class.clone(), + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + // waiting_served_ratio, + // max_batch_prefill_tokens, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, + revision.clone(), + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats_level, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } + _ => None, + }; + + if let Some(ref ua) = user_agent { + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); + tokio::spawn(async move { + start_event.send().await; + }); + }; + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + true + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + let result = start( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + 
api_key, + config, + (tokenizer, tokenizer_config), + (preprocessor_config, processor_config), + hostname, + port, + ngrok, + _ngrok_authtoken, + _ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + model_info, + compat_return_full_text, + allow_origin, + ) + .await; + + if let Some(ua) = user_agent { + match result { + Ok(_) => { + let stop_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Stop, + None, + ); + stop_event.send().await; + Ok(()) + } + Err(e) => { + let description = match usage_stats_level { + usage_stats::UsageStatsLevel::On => Some(e.to_string()), + usage_stats::UsageStatsLevel::NoStack => Some("unknow_error".to_string()), + _ => None, + }; + let event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + description, + ); + event.send().await; + + Err(e) + } + } + } else { + result + } +} + +#[allow(clippy::too_many_arguments)] +async fn start( + backend: impl Backend + Send + Sync + 'static, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + validation_workers: usize, + api_key: Option, + config: Option, + (tokenizer, tokenizer_config): (Option, HubTokenizerConfig), + (preprocessor_config, processor_config): (Option, HubProcessorConfig), + hostname: String, + port: u16, + ngrok: bool, + _ngrok_authtoken: Option, + _ngrok_edge: Option, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + model_info: HubModelInfo, + compat_return_full_text: bool, + allow_origin: Option, +) -> Result<(), WebServerError> { + // Determine the server port based on the feature and environment variable. 
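+    // With the `google` feature enabled, prefer the AIP_HTTP_PORT environment
+    // variable used by Vertex AI prediction containers, and fall back to the
+    // CLI-provided `port` when it is unset or cannot be parsed as a u16.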
+ let port = if cfg!(feature = "google") { + std::env::var("AIP_HTTP_PORT") + .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) + .unwrap_or(port) + } else { + port + }; + + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; + + // Create state let validation = Validation::new( validation_workers, tokenizer, @@ -1656,11 +1851,11 @@ pub async fn run( max_top_n_tokens, max_input_tokens, max_total_tokens, - grammar_support, + disable_grammar_support, ); let infer = Infer::new( - scheduler, + backend, validation, max_concurrent_requests, tokenizer_config, @@ -1697,8 +1892,8 @@ pub async fn run( let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); let batch_size_buckets: Vec = (0..1024).map(|x| (x + 1) as f64).collect(); // Speculated tokens buckets - let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); - let skipped_buckets: Vec = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); + // let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); + // let skipped_buckets: Vec = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); // Prometheus handler let builder = PrometheusBuilder::new() @@ -1711,9 +1906,9 @@ pub async fn run( .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets) .unwrap() .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) - .unwrap() - .set_buckets_for_metric(skipped_matcher, &skipped_buckets) .unwrap(); + // .set_buckets_for_metric(skipped_matcher, &skipped_buckets) + // .unwrap(); let prom_handle = builder .install_recorder() .expect("failed to install metrics recorder"); @@ -1729,18 +1924,18 @@ pub async fn run( let info = Info { model_id: model_info.model_id, model_sha: model_info.sha, - model_dtype: shard_info.dtype, - model_device_type: shard_info.device_type, + // model_dtype: shard_info.dtype, + // model_device_type: shard_info.device_type, model_pipeline_tag: model_info.pipeline_tag, max_concurrent_requests, max_best_of, max_stop_sequences, max_input_tokens, max_total_tokens, - waiting_served_ratio, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, + // waiting_served_ratio, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, validation_workers, max_client_batch_size, router: env!("CARGO_PKG_NAME"), @@ -1806,16 +2001,42 @@ pub async fn run( let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); // Define base and health routes - let base_routes = Router::new() + let mut base_routes = Router::new() .route("/", post(compat_generate)) - .route("/", get(health)) - .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) .route("/v1/completions", post(completions)) .route("/vertex", post(vertex_compatibility)) - .route("/tokenize", post(tokenize)) + .route("/tokenize", post(tokenize)); + + if let Some(api_key) = api_key { + let mut prefix = "Bearer ".to_string(); + prefix.push_str(&api_key); + + // Leak to allow FnMut + let api_key: &'static str = prefix.leak(); + + let auth = move |headers: HeaderMap, + request: axum::extract::Request, + next: axum::middleware::Next| async move { + match headers.get(AUTHORIZATION) { + Some(token) => match token.to_str() { + Ok(token_str) if 
token_str.to_lowercase() == api_key.to_lowercase() => { + let response = next.run(request).await; + Ok(response) + } + _ => Err(StatusCode::UNAUTHORIZED), + }, + None => Err(StatusCode::UNAUTHORIZED), + } + }; + + base_routes = base_routes.layer(axum::middleware::from_fn(auth)) + } + let info_routes = Router::new() + .route("/", get(health)) + .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) .route("/metrics", get(metrics)); @@ -1834,6 +2055,7 @@ pub async fn run( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) + .merge(info_routes) .merge(aws_sagemaker_route); #[cfg(feature = "google")] @@ -1874,7 +2096,6 @@ pub async fn run( // add layers after routes app = app .layer(Extension(info)) - .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(compute_type)) @@ -1912,6 +2133,68 @@ pub async fn run( Ok(()) } +/// get model info from the Huggingface Hub +pub async fn get_hub_model_info(api: &ApiRepo) -> Option { + let response = api.info_request().send().await.ok()?; + + if response.status().is_success() { + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None + } +} + +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { + let config_filename = api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. + let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + api_base_repo.get("tokenizer.json").await.ok() + } else { + None + } +} + +/// get tokenizer_config from the Huggingface Hub +pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(tokenizer_config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
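+    // Parse failures are logged as a warning and mapped to `None` via `.ok()`,
+    // leaving the fallback decision to the caller.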
+ let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) + .map_err(|e| { + tracing::warn!("Unable to parse tokenizer config: {}", e); + e + }) + .ok()?; + + Some(tokenizer_config) +} + /// Shutdown signal handler async fn shutdown_signal() { let ctrl_c = async { @@ -1975,16 +2258,77 @@ impl From for Event { #[derive(Debug, Error)] pub enum WebServerError { - #[error("Unable to connect to the Python model shards: {0}")] - Connection(ClientError), - #[error("Unable to clear the Python model shards cache: {0}")] - Cache(ClientError), - #[error("Unable to get the Python model shards info: {0}")] - Info(ClientError), - #[error("Unable to warmup the Python model shards: {0}")] - Warmup(ClientError), - #[error("Not enough memory to handle `max_total_tokens={0}`")] - NotEnoughMemory(usize), #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), } + +/// Create a post_processor for the LlamaTokenizer +fn create_post_processor( + tokenizer: &Tokenizer, + tokenizer_config: &HubTokenizerConfig, +) -> Result { + let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); + let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); + + let bos_token = tokenizer_config.bos_token.as_ref(); + let eos_token = tokenizer_config.eos_token.as_ref(); + + if add_bos_token && bos_token.is_none() { + panic!("add_bos_token = true but bos_token is None"); + } + + if add_eos_token && eos_token.is_none() { + panic!("add_eos_token = true but eos_token is None"); + } + + let mut single = Vec::new(); + let mut pair = Vec::new(); + let mut special_tokens = Vec::new(); + + if add_bos_token { + if let Some(bos) = bos_token { + let bos_token_id = tokenizer + .token_to_id(bos.as_str()) + .expect("Should have found the bos token id"); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); + } + } + + single.push("$A:0".to_string()); + pair.push("$A:0".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + let eos_token_id = tokenizer + .token_to_id(eos.as_str()) + .expect("Should have found the eos token id"); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); + } + } + + if add_bos_token { + if let Some(bos) = bos_token { + pair.push(format!("{}:1", bos.as_str())); + } + } + + pair.push("$B:1".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + pair.push(format!("{}:1", eos.as_str())); + } + } + + let post_processor = TemplateProcessing::builder() + .try_single(single)? + .try_pair(pair)? 
+ .special_tokens(special_tokens) + .build()?; + + Ok(post_processor) +} diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs new file mode 100644 index 000000000..0282ac634 --- /dev/null +++ b/router/src/usage_stats.rs @@ -0,0 +1,360 @@ +use crate::config::Config; +use clap::ValueEnum; +use csv::ReaderBuilder; +use reqwest::header::HeaderMap; +use serde::Serialize; +use std::{ + fs::File, + io::{self, BufRead}, + path::Path, + process::Command, + time::Duration, +}; +use uuid::Uuid; + +const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; + +#[derive(Copy, Clone, Debug, Serialize, ValueEnum)] +pub enum UsageStatsLevel { + On, + NoStack, + Off, +} + +#[derive(Debug, Clone, Serialize)] +pub struct UserAgent { + pub uid: String, + pub args: Args, + pub env: Env, +} + +impl UserAgent { + pub fn new(reduced_args: Args) -> Self { + Self { + uid: Uuid::new_v4().to_string(), + args: reduced_args, + env: Env::new(), + } + } +} + +#[derive(Serialize, Debug)] +pub enum EventType { + Start, + Stop, + Error, +} + +#[derive(Debug, Serialize)] +pub struct UsageStatsEvent { + user_agent: UserAgent, + event_type: EventType, + #[serde(skip_serializing_if = "Option::is_none")] + error_reason: Option, +} + +impl UsageStatsEvent { + pub fn new(user_agent: UserAgent, event_type: EventType, error_reason: Option) -> Self { + Self { + user_agent, + event_type, + error_reason, + } + } + pub async fn send(&self) { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + let body = serde_json::to_string(&self).unwrap(); + let client = reqwest::Client::new(); + let _ = client + .post(TELEMETRY_URL) + .headers(headers) + .body(body) + .timeout(Duration::from_secs(5)) + .send() + .await; + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct Args { + model_config: Option, + tokenizer_class: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + // waiting_served_ratio: f32, + // max_batch_prefill_tokens: u32, + // max_batch_total_tokens: Option, + // max_waiting_tokens: usize, + // max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + usage_stats_level: UsageStatsLevel, +} + +impl Args { + #[allow(clippy::too_many_arguments)] + pub fn new( + model_config: Option, + tokenizer_class: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + // waiting_served_ratio: f32, + // max_batch_prefill_tokens: u32, + // max_batch_total_tokens: Option, + // max_waiting_tokens: usize, + // max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + usage_stats_level: UsageStatsLevel, + ) -> Self { + Self { + model_config, + tokenizer_class, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + // waiting_served_ratio, + // max_batch_prefill_tokens, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats_level, + } + } +} + +/// This is more or less a copy of the code from the 
`text-generation-launcher` crate to avoid a dependency +#[derive(Serialize, Debug, Clone)] +pub struct Env { + git_sha: &'static str, + docker_label: &'static str, + nvidia_info: Option>, + xpu_info: Option>, + system_env: SystemInfo, +} + +#[derive(Debug, Serialize, Clone)] +struct NvidiaSmiInfo { + name: String, + pci_bus_id: String, + driver_version: String, + pstate: String, + pcie_link_gen_max: String, + pcie_link_gen_current: String, + temperature_gpu: String, + utilization_gpu: String, + utilization_memory: String, + memory_total: String, + memory_free: String, + memory_used: String, + reset_status_reset_required: String, + reset_status_drain_and_reset_recommended: String, + compute_cap: String, + ecc_errors_corrected_volatile_total: String, + mig_mode_current: String, + power_draw_instant: String, + power_limit: String, +} + +impl NvidiaSmiInfo { + fn new() -> Option> { + let output = Command::new("nvidia-smi") + .args([ + "--query-gpu=name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.gpucurrent,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,reset_status.reset_required,reset_status.drain_and_reset_recommended,compute_cap,ecc.errors.corrected.volatile.total,mig.mode.current,power.draw.instant,power.limit", + "--format=csv" + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + + let mut rdr = ReaderBuilder::new() + .has_headers(true) + .from_reader(stdout.as_bytes()); + + let mut infos = Vec::new(); + + for result in rdr.records() { + let record = result.ok()?; + infos.push(NvidiaSmiInfo { + name: record[0].to_string(), + pci_bus_id: record[1].to_string(), + driver_version: record[2].to_string(), + pstate: record[3].to_string(), + pcie_link_gen_max: record[4].to_string(), + pcie_link_gen_current: record[5].to_string(), + temperature_gpu: record[6].to_string(), + utilization_gpu: record[7].to_string(), + utilization_memory: record[8].to_string(), + memory_total: record[9].to_string(), + memory_free: record[10].to_string(), + memory_used: record[11].to_string(), + reset_status_reset_required: record[12].to_string(), + reset_status_drain_and_reset_recommended: record[13].to_string(), + compute_cap: record[14].to_string(), + ecc_errors_corrected_volatile_total: record[15].to_string(), + mig_mode_current: record[16].to_string(), + power_draw_instant: record[17].to_string(), + power_limit: record[18].to_string(), + }); + } + + Some(infos) + } +} + +#[derive(Debug, Serialize, Clone)] +struct XpuSmiInfo { + device_id: usize, + gpu_utilization: f32, + gpu_power: f32, + gpu_core_temperature: f32, + gpu_memory_bandwidth_utilization: f32, +} + +impl XpuSmiInfo { + /// based on this https://github.com/intel/xpumanager/blob/master/doc/smi_user_guide.md#dump-the-device-statistics-in-csv-format + fn new() -> Option> { + let output = Command::new("xpu-smi") + .args([ + "dump", "-d", "-1", "-m", + "0,1,3,17", // Metrics IDs: GPU Utilization, GPU Power, GPU Core Temperature, GPU Memory Bandwidth Utilization + "-n", "1", "-j", + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + let mut infos = Vec::new(); + + let json_data: serde_json::Value = match serde_json::from_str(&stdout) { + Ok(data) => data, + Err(_) => return None, + }; + + if let Some(metrics_data) = json_data.as_array() { + for entry in metrics_data { + let device_id = entry["deviceId"].as_u64()? 
as usize; + let gpu_utilization = entry["metrics"][0].as_f64()? as f32; + let gpu_power = entry["metrics"][1].as_f64()? as f32; + let gpu_core_temperature = entry["metrics"][2].as_f64()? as f32; + let gpu_memory_bandwidth_utilization = entry["metrics"][3].as_f64()? as f32; + + infos.push(XpuSmiInfo { + device_id, + gpu_utilization, + gpu_power, + gpu_core_temperature, + gpu_memory_bandwidth_utilization, + }); + } + } + + Some(infos) + } +} + +#[derive(Serialize, Debug, Clone)] +pub struct SystemInfo { + cpu_count: usize, + cpu_type: String, + total_memory: u64, + architecture: String, + platform: String, +} + +impl SystemInfo { + fn new() -> Self { + let mut system = sysinfo::System::new_all(); + system.refresh_all(); + + let cpu_count = system.cpus().len(); + let cpu_type = system.cpus()[0].brand().to_string(); + let total_memory = system.total_memory(); + let architecture = std::env::consts::ARCH.to_string(); + let platform = format!( + "{}-{}-{}", + std::env::consts::OS, + std::env::consts::FAMILY, + std::env::consts::ARCH + ); + Self { + cpu_count, + cpu_type, + total_memory, + architecture, + platform, + } + } +} + +impl Default for Env { + fn default() -> Self { + Self::new() + } +} + +impl Env { + pub fn new() -> Self { + Self { + system_env: SystemInfo::new(), + nvidia_info: NvidiaSmiInfo::new(), + xpu_info: XpuSmiInfo::new(), + git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), + docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), + } + } +} + +pub fn is_container() -> io::Result { + let path = Path::new("/proc/self/cgroup"); + let file = File::open(path)?; + let reader = io::BufReader::new(file); + + for line in reader.lines() { + let line = line?; + // Check for common container runtimes + if line.contains("/docker/") + || line.contains("/docker-") + || line.contains("/kubepods/") + || line.contains("/kubepods-") + || line.contains("containerd") + || line.contains("crio") + || line.contains("podman") + { + return Ok(true); + } + } + Ok(false) +} diff --git a/router/src/validation.rs b/router/src/validation.rs index 07ad14c9c..3d1a4103f 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -5,13 +5,12 @@ use crate::{ GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, }; use base64::{engine::general_purpose::STANDARD, Engine}; -use image::{io::Reader as ImageReader, ImageFormat}; +use image::{ImageFormat, ImageReader}; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; -use text_generation_client::{Chunk, Image, InputChunk}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -96,7 +95,7 @@ impl Validation { &self, inputs: String, truncate: Option, - ) -> Result)>, ValidationError> { + ) -> Result)>, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -122,7 +121,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? 
{ // Create response channel @@ -181,11 +180,7 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok(( - vec![Chunk::Text(inputs).into()], - input_length, - max_new_tokens, - )) + Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) } } @@ -353,6 +348,14 @@ impl Validation { .compile(&json) .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; + // The schema can be valid but lack properties. + // We need properties for the grammar to be successfully parsed in Python. + // Therefore, we must check and throw an error if properties are missing. + json.get("properties") + .ok_or(ValidationError::InvalidGrammar( + "Grammar must have a 'properties' field".to_string(), + ))?; + // Serialize json to string ValidGrammar::Json( serde_json::to_string(&json) @@ -581,7 +584,7 @@ fn prepare_input( tokenizer: &Tokenizer, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, -) -> Result<(tokenizers::Encoding, Vec), ValidationError> { +) -> Result<(tokenizers::Encoding, Vec), ValidationError> { use Config::*; static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { @@ -593,16 +596,16 @@ fn prepare_input( let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string())); tokenizer_query.push_str(&inputs[start..chunk_start]); } let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; - input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); + input_chunks.push(Chunk::Image(Image { data, mimetype })); tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width)); start = chunk_end; } if start != inputs.len() { - input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); + input_chunks.push(Chunk::Text(inputs[start..].to_string())); tokenizer_query.push_str(&inputs[start..]); } @@ -610,7 +613,7 @@ fn prepare_input( (tokenizer_query, input_chunks) } - _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), + _ => (inputs.clone(), vec![Chunk::Text(inputs)]), }; // Get the number of tokens in the input @@ -623,18 +626,51 @@ fn prepare_input( type TokenizerRequest = ( (String, Option), - oneshot::Sender), ValidationError>>, + oneshot::Sender), ValidationError>>, Span, ); +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Image { + pub data: Vec, + pub mimetype: String, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Chunk { + Text(String), + Image(Image), +} + +/// Convert input chunks to a stringly-typed input for backwards +/// compat for backends that haven't implemented chunked inputs. +pub trait ChunksToString { + /// Convert chunks to string. 
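+    /// Text chunks are appended verbatim; image chunks are re-encoded as
+    /// markdown image tags with a base64 data URL (`![](data:<mime>;base64,<data>)`).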
+ fn chunks_to_string(&self) -> String; +} + +impl ChunksToString for Vec { + fn chunks_to_string(&self) -> String { + let mut output = String::new(); + self.iter().for_each(|c| match &c { + Chunk::Text(text) => output.push_str(text), + Chunk::Image(Image { data, mimetype }) => { + let encoded = STANDARD.encode(data); + output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) + } + }); + output + } +} + #[derive(Debug, Clone)] -pub(crate) enum ValidGrammar { +pub enum ValidGrammar { Json(String), Regex(String), } #[derive(Debug, Clone)] -pub(crate) struct ValidParameters { +pub struct ValidParameters { /// / exponential scaling output probability distribution pub temperature: f32, /// / restricting to the k highest probability elements @@ -658,7 +694,7 @@ pub(crate) struct ValidParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidStoppingParameters { +pub struct ValidStoppingParameters { /// / Maximum number of generated tokens pub max_new_tokens: u32, /// / Optional stopping sequences @@ -669,8 +705,8 @@ pub(crate) struct ValidStoppingParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidGenerateRequest { - pub inputs: Vec, +pub struct ValidGenerateRequest { + pub inputs: Vec, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, @@ -742,6 +778,8 @@ pub enum ValidationError { InvalidImageContent(String), #[error("Could not fetch image: {0}")] FailedFetchImage(#[from] reqwest::Error), + #[error("{0} modality is not supported")] + UnsupportedModality(&'static str), } #[cfg(test)] diff --git a/server/Makefile b/server/Makefile index d701c8198..209fc44e4 100644 --- a/server/Makefile +++ b/server/Makefile @@ -5,6 +5,7 @@ include Makefile-awq include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica +include Makefile-fbgemm unit-tests: pytest -s -vv -m "not private" tests @@ -21,13 +22,15 @@ gen-server: install-server: gen-server pip install pip --upgrade pip install -r requirements_cuda.txt - pip install -e ".[bnb, accelerate, quantize, peft, outlines]" + pip install -e ".[accelerate, quantize, peft, outlines]" install: install-cuda echo "Installed server" -install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention +install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm + pip install -e ".[bnb]" + pip install nvidia-nccl-cu12==2.22.3 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm @@ -35,5 +38,6 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements_cuda.txt --without-hashes --with cuda + poetry export -o requirements_cuda.txt --without-hashes poetry export -o requirements_rocm.txt --without-hashes + poetry export -o requirements_intel.txt --without-hashes diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm new file mode 100644 index 000000000..575265775 --- /dev/null +++ b/server/Makefile-fbgemm @@ -0,0 +1,13 @@ +fbgemm_commit := ddac8dd9fc0bee70a3f456df68b8aac38576c856 + +build-fbgemm: + git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ + cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \ + git submodule update --init --recursive && \ + cd fbgemm_gpu && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 
-gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build + +install-fbgemm: build-fbgemm + cd fbgemm/fbgemm_gpu && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index ba90a74d7..dbddd0f41 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := v2.5.9.post1 +flash_att_v2_commit_cuda := v2.6.1 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 build-flash-attention-v2-cuda: diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 2f2b5ef68..f1f805290 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,14 +1,14 @@ -commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa +commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ git clone https://github.com/Narsil/vllm.git vllm; \ fi - cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build + cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build install-vllm-cuda: build-vllm-cuda - cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e . + cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e . build-vllm-rocm: if [ ! -d 'vllm' ]; then \ diff --git a/server/marlin/COPYRIGHT b/server/marlin/COPYRIGHT deleted file mode 100644 index 69f3b8e64..000000000 --- a/server/marlin/COPYRIGHT +++ /dev/null @@ -1,20 +0,0 @@ -These kernels were vendored from VLLM. The Marlin kernels were developed -by Elias Frantar and extended by Neural Magic. - ---- - -Copyright (C) Marlin.2024 Elias Frantar -Modified by Neural Magic -Copyright 2024 The vLLM team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi deleted file mode 100644 index 663984d01..000000000 --- a/server/marlin/marlin_kernels/__init__.pyi +++ /dev/null @@ -1,61 +0,0 @@ -import torch - -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. This is an extension of - `marlin_gemm` that supports converted GPTQ kernels. - """ - ... - -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. 
This is an extension of - `marlin_gemm` that supports 2:4 sparsity. - """ - ... - -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - """Repack GPTQ parameters for Marlin kernels.""" - ... - -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - """ - Matrix multiplication using Marlin kernels. - """ - ... diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp deleted file mode 100644 index 37eccef66..000000000 --- a/server/marlin/marlin_kernels/ext.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include - -#include "ext.hh" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gptq_marlin_gemm", &gptq_marlin_gemm, - "Marlin gemm with GPTQ compatibility"); - m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm"); - m.def("gptq_marlin_repack", &gptq_marlin_repack, - "Repack GPTQ parameters for Marlin"); - m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); -} diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh deleted file mode 100644 index d1caaab7c..000000000 --- a/server/marlin/marlin_kernels/ext.hh +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else - -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full); - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n, - int64_t num_bits); - -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, - int64_t size_m, int64_t size_n, int64_t size_k); - -#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin.cu b/server/marlin/marlin_kernels/gptq_marlin.cu deleted file mode 100644 index 0beb9de14..000000000 --- a/server/marlin/marlin_kernels/gptq_marlin.cu +++ /dev/null @@ -1,1870 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/* - * Adapted from https://github.com/IST-DASLab/marlin - */ - -#include "gptq_marlin.cuh" -#include "gptq_marlin_dtypes.cuh" - -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace gptq_marlin { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) {} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -} // namespace gptq_marlin - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. 
-template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -template -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_4bit(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. 
-template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; - - scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -// For a given "a" of size [M,K] performs a permutation of the K columns based -// on the given "perm" indices. 
-__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; - int finish_row = start_row + block_rows; - if (finish_row > size_m) { - finish_row = size_m; - } - int cur_block_rows = finish_row - start_row; - - int row_stride = size_k * sizeof(half) / 16; - - auto permute_row = [&](int row) { - int iters = size_k / default_threads; - int rest = size_k % default_threads; - - int offset = row * row_stride; - - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); - - int base_k = 0; - - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - - base_k += default_threads; - } - - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - } - } - }; - - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. 
- using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - - constexpr int pack_factor = 32 / num_bits; - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - div_ceil(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Global A read index of current thread. 
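Plugging one concrete configuration into the stride formulas above is a quick sanity check of the shared-memory layout. The numbers below assume the 4-bit, 256-thread variant with an 8x8 block tile (a 16x128x128 thread tile in elements) and are illustrative only:

// Compile-time sketch, assumed config: threads = 256, thread_m_blocks = 1,
// thread_k_blocks = 8 (one of the CALL_IF variants further down).
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

constexpr int threads = 256, thread_m_blocks = 1, thread_k_blocks = 8;
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;                     // 16 int4 per shared row
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;                 // 16
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);  // 256
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);          // 256 int4 per stage
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);        // 1
static_assert(a_sh_wr_iters == 1, "each thread writes exactly one int4 per A tile");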
- int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. 
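The bank-conflict claim behind `transform_a` can be checked with a few lines of host code: XOR-ing the column index with the row index (mod 8) is a permutation within each row, and it also spreads a fixed logical column across 8 distinct physical columns. A standalone sketch of that property, treating the 8 int4 columns of a row as 8 independent 128-bit bank groups:

#include <cassert>
#include <cstdio>

int main() {
  for (int r = 0; r < 8; r++) {       // row-wise writes remain a permutation
    bool seen[8] = {};
    for (int c = 0; c < 8; c++) seen[c ^ r] = true;
    for (int c = 0; c < 8; c++) assert(seen[c]);
  }
  for (int c = 0; c < 8; c++) {       // column-wise fragment reads spread out
    bool seen[8] = {};
    for (int r = 0; r < 8; r++) seen[c ^ r] = true;
    for (int r = 0; r < 8; r++) assert(seen[r]);
  }
  std::printf("xor swizzle is conflict-free in both directions\n");
  return 0;
}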
- int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. 
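The `extern __shared__ int4 sh[]` carving above is the usual single-allocation pattern for multi-stage pipelines: one dynamic blob, partitioned by pointer arithmetic, sized on the host at launch. A generic sketch of the same pattern (stage sizes are placeholders, not the kernel's values):

template <int STAGES, int A_STAGE, int B_STAGE, int S_STAGE>
__global__ void carve_demo() {
  extern __shared__ int4 sh[];
  int4* sh_a = sh;                       // STAGES * A_STAGE int4
  int4* sh_b = sh_a + STAGES * A_STAGE;  // STAGES * B_STAGE int4
  int4* sh_s = sh_b + STAGES * B_STAGE;  // STAGES * S_STAGE int4
  (void)sh_s;
}

// Host side, sketch: a single size covers all regions.
//   size_t smem = 4 * (256 + 128 + 16) * sizeof(int4);
//   carve_demo<4, 256, 128, 16><<<grid, 256, smem>>>();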
- auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
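fetch_to_shared / wait_for_stage implement a classic cp.async software pipeline: `stages` buffers are kept in flight, and the consumer waits until at most `stages - 2` copies are still pending so the producer can never overwrite the buffer currently being read. The same discipline in a minimal, standalone form, using the CUDA pipeline built-ins (CUDA 11+) instead of the raw PTX wrappers; names and tile shapes here are illustrative:

#include <cuda_pipeline_primitives.h>  // __pipeline_memcpy_async & friends

template <int STAGES, int TILE_INT4S>
__global__ void staged_copy(const int4* __restrict__ src, int4* __restrict__ dst,
                            int num_tiles) {
  extern __shared__ int4 sh[];  // STAGES * TILE_INT4S elements
  auto issue = [&](int tile, int stage) {
    if (tile < num_tiles) {
      for (int i = threadIdx.x; i < TILE_INT4S; i += blockDim.x)
        __pipeline_memcpy_async(&sh[stage * TILE_INT4S + i],
                                &src[tile * TILE_INT4S + i], sizeof(int4));
    }
    __pipeline_commit();  // commit even for empty stages, as in the kernel above
  };
  for (int s = 0; s < STAGES - 1; s++) issue(s, s);             // prime the pipe
  for (int tile = 0; tile < num_tiles; tile++) {
    __pipeline_wait_prior(STAGES - 2);                          // stage `tile % STAGES` is ready
    __syncthreads();
    int stage = tile % STAGES;
    for (int i = threadIdx.x; i < TILE_INT4S; i += blockDim.x)  // "consume": copy back out
      dst[tile * TILE_INT4S + i] = sh[stage * TILE_INT4S + i];
    __syncthreads();
    issue(tile + STAGES - 1, (tile + STAGES - 1) % STAGES);     // refill the freed slot
  }
}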
- auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = 
sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - - } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
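For reference, the effect of the 4-bit branch of the matmul lambda — expand eight packed fields and apply a scale — can be written in a plain, slow form. The sketch below ignores the interleaved packing order produced by the repack kernel and assumes the symmetric zero-point of 8 used by the 4-bit path; the real code instead performs the conversion with lop3 bit tricks on half2 lanes:

#include <cstdint>
#include <cuda_fp16.h>

// Slow-path sketch only, not the kernel's dequant_4bit.
__device__ inline void dequant_4bit_ref(uint32_t packed, half scale, half out[8]) {
  for (int i = 0; i < 8; i++) {
    int q = (packed >> (4 * i)) & 0xF;             // one 4-bit field
    out[i] = __hmul(__int2half_rn(q - 8), scale);  // assumed zero-point of 8
  }
}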
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
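The serial L2 reduction above is coordinated through the `locks` array passed into the kernel: the blocks of one column slice take turns in `slice_idx` order, and only the last one resets the counter. barrier_acquire / barrier_release themselves are not shown in this hunk; the sketch below is one plausible shape of that handshake, not the removed implementation:

__device__ void slice_acquire(int* lock, int slice_idx) {
  if (threadIdx.x == 0)
    while (atomicAdd(lock, 0) != slice_idx) {}  // spin until it is this block's turn
  __syncthreads();
}

__device__ void slice_release(int* lock, bool last) {
  __syncthreads();
  if (threadIdx.x == 0) {
    __threadfence();                 // make this block's C writes visible first
    if (last) atomicExch(lock, 0);   // last block in the slice resets the counter
    else      atomicAdd(lock, 1);    // otherwise hand the token to the next block
  }
}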
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { - res = __hmul2(res, s[0]); - } - - ((scalar_t2*)sh)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. 
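The `write` helper above packs two fp32 accumulators into a half2 and, only for per-column 4-bit quantization, folds the scale in at this late stage (grouped scales were already applied during dequantization). A minimal sketch of that step using the plain fp16 intrinsics, ignoring the ScalarType abstraction:

#include <cuda_fp16.h>

__device__ inline half2 pack_and_scale(float c0, float c1, half2 col_scale,
                                       bool apply_scale) {
  half2 res = __halves2half2(__float2half(c0), __float2half(c1));
  return apply_scale ? __hmul2(res, col_scale) : res;  // per-column scale, 4-bit only
}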
- - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float( - reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float( - reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( - reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma 
unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ - prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}, -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}, - -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { - bool cache_scales_chunk = has_act_order && !is_k_full; - - int tb_n = th_config.thread_n; - int tb_k = th_config.thread_k; - - // Get max scale groups per thread-block - int tb_groups; - if (group_size == -1) { - tb_groups = 1; - } else if (group_size == 0) { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size - } else { - tb_groups = div_ceil(tb_k, group_size); - } - - if (cache_scales_chunk) { - int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; - - } else { - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * pipe_stages; - } -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, 
int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, - void* g_idx, void* perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, 
void* workspace, int num_bits, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int dev, cudaStream_t stream, int thread_k, - int thread_n, int sms, int max_par) { - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = 0; - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); - group_blocks = 0; - } - - } else { - if (group_size == -1) { - group_blocks = -1; - } else { - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - const int* g_idx_ptr = (const int*)g_idx; - const int* perm_ptr = (const int*)perm; - int4* a_tmp_ptr = (int4*)a_tmp; - - int* locks = (int*)workspace; - - if (has_act_order) { - // Permute A columns - int block_rows = div_ceil(prob_m, blocks); - permute_cols_kernel<<>>( - A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); - A_ptr = a_tmp_ptr; - } - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by having - // a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - - // Main loop - for (int i = 0; i < 
tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; - } - - // Define kernel configurations - if (false) { - } - CALL_IF(4, 32, 2, 256) - CALL_IF(4, 16, 4, 256) - CALL_IF(4, 8, 8, 256) - CALL_IF(4, 8, 4, 128) - CALL_IF(4, 4, 8, 128) - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace gptq_marlin - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify A - TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), - ", size_m = ", size_m); - TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), - ", size_k = ", size_k); - - // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", gptq_marlin::tile_size); - int actual_size_n = - (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, - ", actual_size_n = ", actual_size_n); - - // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); - TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = 
torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Verify g_idx and perm - TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || - (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = ", g_idx.size(0), - " and perm.size(0) = ", perm.size(0), - ", where size_k = ", size_k); - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - bool has_act_order = g_idx.size(0) != 0; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); - TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), - " is not size_n = ", size_n); - num_groups = b_scales.size(0); - - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - - } else { - if (num_groups > 1) { - TORCH_CHECK( - size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); - group_size = size_k / num_groups; - } else { - group_size = -1; - } - } - - // Verify workspace size - TORCH_CHECK( - size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); - int min_workspace_size = - (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = ", workspace.numel(), - " is below min_workspace_size = ", min_workspace_size); - - int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, - group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, gptq_marlin::max_par); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, - is_k_full, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - gptq_marlin::max_par); - } else { - TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); - } - - return c; -} - -#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin.cuh b/server/marlin/marlin_kernels/gptq_marlin.cuh deleted file mode 100644 index 42af44951..000000000 --- a/server/marlin/marlin_kernels/gptq_marlin.cuh +++ /dev/null @@ -1,76 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace gptq_marlin { - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 
1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -static constexpr int default_threads = 256; - -static constexpr int pipe_stages = - 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else - -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -#endif - -} // namespace gptq_marlin diff --git a/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh b/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh deleted file mode 100644 index ca1b7099d..000000000 --- a/server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh +++ /dev/null @@ -1,77 +0,0 @@ - -#ifndef _data_types_cuh -#define _data_types_cuh -#include "gptq_marlin.cuh" -#include -#include - -namespace gptq_marlin { - -template -class ScalarType {}; - -template <> -class ScalarType { - public: - using scalar_t = half; - using scalar_t2 = half2; - - // Matrix fragments for tensor core instructions; their precise layout is - // documented here: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - - static __device__ float inline num2float(const half x) { - return __half2float(x); - } - - static __device__ half2 inline num2num2(const half x) { - return __half2half2(x); - } - - static __device__ half2 inline nums2num2(const half x1, const half x2) { - return __halves2half2(x1, x2); - } - - static __host__ __device__ half inline float2num(const float x) { - return __float2half(x); - } -}; - -template <> -class ScalarType { - public: - using scalar_t = nv_bfloat16; - using scalar_t2 = nv_bfloat162; - - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static __device__ float inline num2float(const nv_bfloat16 x) { - return __bfloat162float(x); - } - - static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { - return __bfloat162bfloat162(x); - } - - static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, - const nv_bfloat16 x2) { - return __halves2bfloat162(x1, x2); - } - - static __host__ __device__ nv_bfloat16 inline float2num(const float x) 
{ - return __float2bfloat16(x); - } -#endif -}; - -} // namespace gptq_marlin - -#endif diff --git a/server/marlin/marlin_kernels/gptq_marlin_repack.cu b/server/marlin/marlin_kernels/gptq_marlin_repack.cu deleted file mode 100644 index 4adc158eb..000000000 --- a/server/marlin/marlin_kernels/gptq_marlin_repack.cu +++ /dev/null @@ -1,350 +0,0 @@ -#include "gptq_marlin.cuh" - -namespace gptq_marlin { - -static constexpr int repack_stages = 8; - -static constexpr int repack_threads = 256; - -static constexpr int tile_k_size = tile_size; -static constexpr int tile_n_size = tile_k_size * 4; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template -__global__ void marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) {} - -} // namespace gptq_marlin - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -template -__global__ void marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) { - constexpr int pack_factor = 32 / num_bits; - - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; - int block_k_tiles = div_ceil(k_tiles, gridDim.x); - - int start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) { - return; - } - - int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - extern __shared__ int4 sh[]; - - constexpr int perm_size = tile_k_size / 4; - - int4* sh_perm_ptr = sh; - int4* sh_pipe_ptr = sh_perm_ptr; - if constexpr (has_perm) { - sh_pipe_ptr += perm_size; - } - - constexpr int tile_ints = tile_k_size / pack_factor; - - constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = has_perm ? 
tile_k_size : tile_ints; - constexpr int stage_size = stage_k_threads * stage_n_threads; - - auto load_perm_to_shared = [&](int k_tile_id) { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; - - int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); - - if (threadIdx.x < perm_size) { - sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; - } - __syncthreads(); - }; - - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - cp_async_fence(); - return; - } - - int first_n = n_tile_id * tile_n_size; - - int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; - - if constexpr (has_perm) { - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - uint32_t const* sh_perm_int_ptr = - reinterpret_cast(sh_perm_ptr); - - int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor; - - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&( - b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); - } - - } else { - if (threadIdx.x < stage_size) { - int k_id = threadIdx.x / stage_n_threads; - int n_id = threadIdx.x % stage_n_threads; - - int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor; - - cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + - first_n + (n_id * 4)]))); - } - } - - cp_async_fence(); - }; - - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { - if (n_tile_id >= n_tiles) { - return; - } - - int warp_id = threadIdx.x / 32; - int th_id = threadIdx.x % 32; - - if (warp_id >= 4) { - return; - } - - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; - - constexpr int tc_offsets[4] = {0, 1, 8, 9}; - - int cur_n = warp_id * 16 + tc_col; - - constexpr int sh_stride = 64; - constexpr uint32_t mask = (1 << num_bits) - 1; - - int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - - uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - - uint32_t vals[8]; - - if constexpr (has_perm) { - for (int i = 0; i < 4; i++) { - int k_idx = tc_row + tc_offsets[i]; - - uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor; - - uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; - - uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; - - vals[i] = b1_cur_val; - vals[4 + i] = b2_cur_val; - } - - } else { - uint32_t b1_vals[tile_ints]; - uint32_t b2_vals[tile_ints]; - - #pragma unroll - for (int i = 0; i < tile_ints; i++) { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; - } - - #pragma unroll - for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; - int cur_int = cur_elem / pack_factor; - int cur_pos = cur_elem % pack_factor; - - vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; - } - } - - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; - - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if 
constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - - uint32_t res = 0; - #pragma unroll - for (int i = 0; i < 8; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } - - out_ptr[out_offset + th_id * 4 + warp_id] = res; - - } else { - constexpr int pack_idx[4] = {0, 2, 1, 3}; - - uint32_t res1 = 0; - uint32_t res2 = 0; - #pragma unroll - for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } - - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } - }; - - auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } - - wait_for_stage(); - }; - #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { - int n_tile_id = 0; - - if constexpr (has_perm) { - load_perm_to_shared(k_tile_id); - } - - start_pipes(k_tile_id, n_tile_id); - - while (n_tile_id < n_tiles) { - #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, - n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); - } - n_tile_id += repack_stages; - } - } -} - -} // namespace gptq_marlin - - #define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin::marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin::marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - // Verify compatibility with marlin tile of 16x64 - TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, - " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); - TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, - " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); - - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. 
Got = ", num_bits); - int const pack_factor = 32 / num_bits; - - // Verify B - TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", pack_factor = ", pack_factor); - TORCH_CHECK(b_q_weight.size(1) == size_n, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not size_n = ", size_n); - - // Verify device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); - TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - torch::Tensor out = - torch::empty({size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / pack_factor}, - options); - - // Detect if there is act_order - bool has_perm = perm.size(0) != 0; - - // Get ptrs - uint32_t const* b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); - uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); - - // Get dev info - int dev = b_q_weight.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - int blocks; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (false) { - } - CALL_IF(4, false) - CALL_IF(4, true) - CALL_IF(8, false) - CALL_IF(8, true) - else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, - ", has_perm = ", has_perm); - } - - return out; -} - -#endif diff --git a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu b/server/marlin/marlin_kernels/marlin_cuda_kernel.cu deleted file mode 100644 index d124c0149..000000000 --- a/server/marlin/marlin_kernels/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1136 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. 
Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. 
We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. 
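The striped partitioning described above is easiest to see numerically. The following standalone host-side sketch (illustrative only, not part of the kernel) reproduces the 3x3-tile / 5-SM example: each SM walks `iters` consecutive tiles down a k-stripe and wraps into the next column, which yields exactly the 0 1 3 / 0 2 3 / 1 2 4 assignment shown in the comment.

#include <cstdio>

static int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int k_tiles = 3, n_tiles = 3, sms = 5;
  const int iters = ceildiv(k_tiles * n_tiles, sms);  // tiles walked per SM
  int owner[3][3];                                    // owner[k][n] = SM id
  for (int sm = 0; sm < sms; ++sm)
    for (int i = 0; i < iters; ++i) {
      int t = sm * iters + i;                         // global tile index
      if (t >= k_tiles * n_tiles) break;
      owner[t % k_tiles][t / k_tiles] = sm;           // walk down k, wrap into next n column
    }
  for (int k = 0; k < k_tiles; ++k) {
    for (int n = 0; n < n_tiles; ++n) printf("%d ", owner[k][n]);
    printf("\n");                                     // prints: 0 1 3 / 0 2 3 / 1 2 4
  }
  return 0;
}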
- - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int 
a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x; - int b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. 
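The effect of the XOR-based swizzle computed by `transform_a` above can be seen in isolation with a small sketch (illustrative only; the 8 columns per row here are an assumed example width, the kernel uses its own tile strides): XOR-ing the column index with the row index permutes each row's columns, so the 16-byte elements of one logical column are spread across all bank groups instead of piling up on one.

#include <cstdio>

int main() {
  const int cols = 8;  // assumed int4 elements per shared-memory row (example width)
  for (int row = 0; row < 8; ++row) {
    for (int col = 0; col < cols; ++col)
      printf("%d ", col ^ row);   // swizzled column actually used for this row
    printf("\n");
  }
  // Every printed row is a permutation of 0..7, and so is every printed
  // column, so a strided walk down one logical column touches all bank
  // groups instead of hammering a single one.
  return 0;
}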
- int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. 
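For readers less familiar with the raw cp.async PTX, the same stages-deep global-to-shared prefetch pattern that `fetch_to_shared`, `cp_async_fence` and `cp_async_wait` implement above can be sketched with the documented <cuda_pipeline.h> primitives. This is a minimal, hedged illustration (assumes CUDA 11+ and sm_80 for truly asynchronous copies; names such as `pipelined_sum` are made up for the example), not the deleted kernel's code.

#include <cuda_pipeline.h>
#include <cstdio>

constexpr int kStages = 4;    // pipeline depth, mirroring the kernel's `stages`
constexpr int kTile   = 256;  // floats per tile == threads per block here

__global__ void pipelined_sum(const float* in, float* out, int n_tiles) {
  __shared__ float buf[kStages][kTile];
  float acc = 0.f;
  // Prime the pipeline with the first kStages - 1 fetches.
  for (int s = 0; s < kStages - 1; ++s) {
    __pipeline_memcpy_async(&buf[s][threadIdx.x],
                            &in[(size_t)s * kTile + threadIdx.x], sizeof(float));
    __pipeline_commit();
  }
  for (int t = 0; t < n_tiles; ++t) {
    int next = t + kStages - 1;                       // tile fetched this iteration
    if (next < n_tiles)
      __pipeline_memcpy_async(&buf[next % kStages][threadIdx.x],
                              &in[(size_t)next * kTile + threadIdx.x],
                              sizeof(float));
    __pipeline_commit();
    __pipeline_wait_prior(kStages - 2);               // tile t is now resident
    __syncthreads();
    acc += buf[t % kStages][threadIdx.x];             // "compute" on the ready tile
    __syncthreads();                                  // done reading before overwrite
  }
  atomicAdd(out, acc);
}

int main() {
  const int n_tiles = 16, n = n_tiles * kTile;
  float *d_in, *d_out, *h_in = new float[n];
  for (int i = 0; i < n; ++i) h_in[i] = 1.f;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemset(d_out, 0, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  pipelined_sum<<<1, kTile>>>(d_in, d_out, n_tiles);
  float result;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f (expected %d)\n", result, n);
  delete[] h_in; cudaFree(d_in); cudaFree(d_out);
  return 0;
}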
- if (group_blocks != -1) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. 
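`thread_block_reduce` above combines the partial sums of the warps that cover the same output tile with a logarithmic number of shared-memory steps. A stripped-down, standalone sketch of that idea follows (names and sizes are illustrative only, not the kernel's fragment-layout reduction):

#include <cstdio>

__global__ void warp_partial_reduce(float* out) {
  __shared__ float partial[8][32];            // 8 warps x 32 lanes
  int warp = threadIdx.x / 32, lane = threadIdx.x % 32;
  partial[warp][lane] = warp + 1.0f;          // stand-in for each warp's partial sums
  __syncthreads();
  for (int off = 4; off > 0; off /= 2) {      // log2(8) = 3 reduction steps
    if (warp < off) partial[warp][lane] += partial[warp + off][lane];
    __syncthreads();
  }
  if (warp == 0) out[lane] = partial[0][lane];  // 1 + 2 + ... + 8 = 36 per lane
}

int main() {
  float* d_out;
  cudaMalloc(&d_out, 32 * sizeof(float));
  warp_partial_reduce<<<1, 256>>>(d_out);
  float h[32];
  cudaMemcpy(h, d_out, sizeof(h), cudaMemcpyDeviceToHost);
  printf("lane 0 = %f (expected 36)\n", h[0]);
  cudaFree(d_out);
  return 0;
}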
- auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
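`global_reduce`, together with the `barrier_acquire`/`barrier_release` pair, serializes the partial results of the blocks in one column slice through L2. A minimal standalone sketch of that handshake, using plain atomics instead of the hand-written PTX (illustrative only; the real kernel reduces fp16 tiles rather than a single float):

#include <cstdio>

__global__ void serial_block_reduce(float* out, int* lock) {
  float partial = blockIdx.x + 1.0f;                // this block's partial result
  if (threadIdx.x == 0) {
    while (atomicAdd(lock, 0) != blockIdx.x) {}     // acquire: wait for our turn
    out[0] += partial;                              // reduce directly in global memory
    __threadfence();                                // make the update visible ...
    atomicAdd(lock, 1);                             // ... before releasing the lock
  }
}

int main() {
  float* d_out; int* d_lock;
  cudaMalloc(&d_out, sizeof(float));
  cudaMalloc(&d_lock, sizeof(int));
  cudaMemset(d_out, 0, sizeof(float));
  cudaMemset(d_lock, 0, sizeof(int));
  serial_block_reduce<<<4, 32>>>(d_out, d_lock);
  float h;
  cudaMemcpy(&h, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f (expected 10)\n", h);            // 1 + 2 + 3 + 4
  cudaFree(d_out); cudaFree(d_lock);
  return 0;
}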
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
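The `write` helper defined just below shows the write-out trick: fp32 accumulators are packed into `half2` values and, for per-column quantization, the scale is applied only once here instead of inside the inner loop. A tiny self-contained sketch of that conversion (hypothetical names, illustrative only):

#include <cuda_fp16.h>
#include <cstdio>

__global__ void pack_and_scale(const float* acc, float scale, float* out) {
  // Pack two fp32 accumulators into one half2 ...
  half2 res = __halves2half2(__float2half(acc[0]), __float2half(acc[1]));
  // ... and apply the per-column scale exactly once, at write-out time.
  res = __hmul2(res, __half2half2(__float2half(scale)));
  out[0] = __half2float(__low2half(res));
  out[1] = __half2float(__high2half(res));
}

int main() {
  float h_acc[2] = {3.0f, -2.0f}, h_out[2];
  float *d_acc, *d_out;
  cudaMalloc(&d_acc, sizeof(h_acc));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_acc, h_acc, sizeof(h_acc), cudaMemcpyHostToDevice);
  pack_and_scale<<<1, 1>>>(d_acc, 0.5f, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("%f %f (expected 1.5 -1.0)\n", h_out[0], h_out[1]);
  cudaFree(d_acc); cudaFree(d_out);
  return 0;
}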
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. 
-const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* 
s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. 
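The `if (false) {} CALL_IF(...) else { ... }` chain used just below is a common idiom for mapping runtime parameters onto a fixed set of template instantiations: every macro expansion begins with `else if`, so the dummy `if (false)` merely anchors the chain and the trailing `else` catches unsupported shapes. A compact sketch of the idiom (hypothetical names, not the kernel's actual dispatcher):

#include <cstdio>
#include <stdexcept>
#include <string>

template <int N_BLOCKS, int K_BLOCKS>
void run_kernel(int m) {
  printf("launching instantiation <%d, %d> for m = %d\n", N_BLOCKS, K_BLOCKS, m);
}

#define DISPATCH_IF(N_BLOCKS, K_BLOCKS)                      \
  else if (n_blocks == N_BLOCKS && k_blocks == K_BLOCKS) {   \
    run_kernel<N_BLOCKS, K_BLOCKS>(m);                       \
  }

void dispatch(int m, int n_blocks, int k_blocks) {
  if (false) {
  }
  DISPATCH_IF(8, 8)
  DISPATCH_IF(16, 4)
  DISPATCH_IF(8, 4)
  DISPATCH_IF(4, 8)
  else {
    throw std::runtime_error("Unsupported config: n_blocks = " +
                             std::to_string(n_blocks) +
                             ", k_blocks = " + std::to_string(k_blocks));
  }
}

int main() {
  dispatch(64, 16, 4);   // prints: launching instantiation <16, 4> for m = 64
  return 0;
}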
- if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin::tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(marlin::tile_size)); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin::tile_size)); - - int actual_size_n = - (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? 
-1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK( - size_n % marlin::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(marlin::min_thread_n)); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms, marlin::max_par); - - return c; -} diff --git a/server/marlin/marlin_kernels/sparse/common/base.h b/server/marlin/marlin_kernels/sparse/common/base.h deleted file mode 100644 index 16018d331..000000000 --- a/server/marlin/marlin_kernels/sparse/common/base.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace marlin_24 { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -template -struct ShapeBase { - static constexpr int M = M_, N = N_, K = K_; -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragM = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mem.h b/server/marlin/marlin_kernels/sparse/common/mem.h deleted file mode 100644 index 83e3578d2..000000000 --- a/server/marlin/marlin_kernels/sparse/common/mem.h +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" - -namespace marlin_24 { -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void* smem_ptr, - const void* glob_ptr, - bool pred = true, - const bool zfill = false) { - const int BYTES = 16; - int src_in_bytes = (zfill ? 0 : BYTES); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); -} - -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_m); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. 
-__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/common/mma.h b/server/marlin/marlin_kernels/sparse/common/mma.h deleted file mode 100644 index b26505f77..000000000 --- a/server/marlin/marlin_kernels/sparse/common/mma.h +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "base.h" -#include - -namespace marlin_24 { - -// On CUDA earlier than 12.5, the ordered_metadata version of this instruction -// is not supported. On later versions of CUDA the version without ordered -// metadata results in the following warning: -// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction -// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially -// | reduced performance on some future architectures -#if defined CUDA_VERSION && CUDA_VERSION >= 12050 - #define MMA_SP_INST \ - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#else - #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " -#endif - -// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. 
-__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, - const FragA& frag_b, FragC& frag_c, FragM& frag_m, - const int psel) { - const uint32_t* a0 = reinterpret_cast(&a_frag0); - const uint32_t* a1 = reinterpret_cast(&a_frag1); - const uint32_t* b = reinterpret_cast(&frag_b); - const uint32_t* e = reinterpret_cast(&frag_m); - - float* c = reinterpret_cast(&frag_c); - if (psel == 0) { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } else { - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); - } -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, - float c3) { - uint2 r; - asm("{\n\t" - ".reg .f16 a, b, c, d; \n\t" - "cvt.rn.f16.f32 a, %2; \n\t" - "cvt.rn.f16.f32 b, %3; \n\t" - "cvt.rn.f16.f32 c, %4; \n\t" - "cvt.rn.f16.f32 d, %5; \n\t" - "mov.b32 %0, {a, b}; \n\t" - "mov.b32 %1, {c, d}; \n\t" - "}" - : "=r"(r.x), "=r"(r.y) - : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); - return r; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_4bit(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. 
- const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_8bit(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, - FragS& s0, float* c4, float* c5, float* c6, - float* c7, FragS& s1) { - *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); - *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); - *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); - *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); - - *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); - *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); - *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); - *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); -} - -} // namespace marlin_24 diff --git a/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu b/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu deleted file mode 100644 index b5effc305..000000000 --- a/server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu +++ /dev/null @@ -1,1125 +0,0 @@ -/* - * Notice: This file was modified by Neuralmagic inc to include 8-bit support - * - * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All - * Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -#else - - #include "common/mem.h" - #include "common/mma.h" - -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_24 { - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -static constexpr int THREADS = 256; -static constexpr int STAGES = 4; - -static constexpr int min_thread_n = 128; - -static constexpr int tile_size = 16; -static constexpr int max_par = 64; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. 
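The stripe assignment sketched in the 3x3 matrix / 5 SM example above can be reproduced with a few lines of host code. This is a standalone illustration under the assumption that tiles are enumerated with the k index fastest and that threadblock b owns the contiguous index range [iters * b, iters * (b + 1)), which is what the slice_row / slice_col_par arithmetic in the kernel body below implements.

#include <cstdio>

static int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  // The 3x3 tile grid, 5 SM example from the comment above.
  const int k_tiles = 3, n_tiles = 3, num_blocks = 5;
  const int iters = ceildiv(k_tiles * n_tiles, num_blocks);  // tiles per threadblock
  for (int row = 0; row < k_tiles; ++row) {
    for (int col = 0; col < n_tiles; ++col) {
      int t = col * k_tiles + row;    // enumerate tiles with k fastest
      std::printf("%d ", t / iters);  // threadblock that owns this tile
    }
    std::printf("\n");
  }
  // Expected output (matching the comment):
  // 0 1 3
  // 0 2 3
  // 1 2 4
  return 0;
}

Each threadblock therefore works on a contiguous stripe that may straddle two column slices, which is why the kernel tracks slice_count and slice_idx for the cross-block reductions.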
- - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - // number of thread_k_blocks in k-dim - int k_tiles = prob_k / 32 / thread_k_blocks; - // number of thread_n_blocks in n-dim - int n_tiles = prob_n / 16 / thread_n_blocks; - // iters needed to cover all slices - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - // number of threadblock tiles in the current slice - int slice_iters; - // total number of active threadblocks in the current slice - int slice_count = 0; - // index of threadblock in current slice; numbered bottom to top - int slice_idx; - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 32 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads //RLC: 2 * #warps k-dim - constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall 
size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - constexpr int pack_factor = 32 / num_bits; - - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 - constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp - int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; - int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); - constexpr int m_sh_wr_delta = threads / 2; - constexpr int m_sh_rd_delta = threads / 2; - constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; - constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + - (threadIdx.x % (m_sh_stride)); - m_gl_rd += (m_sh_stride)*slice_col; - m_gl_rd += m_gl_rd_delta_o * slice_row; - int m_sh_wr = threadIdx.x; - int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; - - int s_gl_rd; - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - - int s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } else { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. 
- bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) { - a_sh_rd_trans[0][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - a_sh_rd_trans[1][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); - } - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4* meta_ptr[m_sh_iters]; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - int4* sh_m = sh_s + (stages * s_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks][2]; - I4 frag_b_quant[2][b_thread_vecs]; - FragM frag_m[2][2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. 
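As a side note on the XOR-based layout used by transform_a above: the sketch below is a simplified standalone model (the tile shape is illustrative, not the kernel's exact strides). It only demonstrates the core property the trick relies on: XOR-ing the column index with the row index permutes the columns within every row, so no two logical slots collide, while the same logical column lands in a different physical column, and hence a different group of shared-memory banks, in each row.

#include <cstdio>

int main() {
  const int rows = 8, cols = 8;  // illustrative tile of 16-byte (int4) slots
  for (int row = 0; row < rows; ++row) {
    for (int col = 0; col < cols; ++col) {
      std::printf("%d ", col ^ row);  // physical slot for logical column `col`
    }
    std::printf("\n");  // every line is a permutation of 0..cols-1
  }
  return 0;
}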
- auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - int4* sh_meta_stage = sh_m + m_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) { - if (m_sh_wr_pred) - cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); - meta_ptr[i] += m_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if (group_blocks != -1) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - ldsm4(frag_a[k % 2][i][0], - &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); - ldsm4(frag_a[k % 2][i][1], - &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); - } - - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - - // Load meta with ldsm4 - int4* sh_m_stage = sh_m + m_sh_stage * pipe; - ldsm4_m(frag_m[k % 2][0], - &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. 
- #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - - } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } - - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], - frag_m[k % 2][j / 2], j % 2); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). 
- constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; - int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) - int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int col = 2 * ((threadIdx.x % 32) % 4); - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2] += - __half2float( - reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); - } - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2]); - } - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: - - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - - int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) - - constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: - int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + - (threadIdx.x % (2 * 2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, - float c4, float c5, float c6, float c7, FragS& s1) { - uint2 res[2]; - res[0] = to_half4(c0, c1, c2, c3); - res[1] = to_half4(c4, c5, c6, c7); - half2* tmp = (half2*)&res; - // for per-column quantization we finally apply the scale here - if constexpr (group_blocks == -1 && num_bits == 4) { - tmp[0] = __hmul2(tmp[0], s0[0]); - tmp[1] = __hmul2(tmp[1], s0[1]); - tmp[2] = __hmul2(tmp[2], s1[0]); - tmp[3] = __hmul2(tmp[3], s1[1]); - } - ((int4*)sh)[idx] = *((int4*)&res[0]); - }; - - // RLC: only warp 0 and 1 baseline example - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - int wr = c_sh_wr; - write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], - frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], - frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], - frag_s[0][2]); - write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], - frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], - frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], - frag_c[i][3][0][3], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], - frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], - frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], - frag_c[i][3][1][2], frag_s[0][2]); - write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], - frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], - frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], - frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); - - c_sh_wr += 8 * c_sh_stride_2; - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. 
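The scheduling that the main loop below implements can be summarized with a small standalone model (an illustration only; the buffer indices follow the fetch_to_shared and matmul calls in the loop, the printing scaffold is assumed). With S pipeline stages, the compute step that consumes shared buffer p issues the asynchronous copy into buffer (p + S - 1) % S, so each buffer is filled S - 1 iterations before it is consumed and, as noted in wait_for_stage above, S - 2 copies remain in flight after each wait.

#include <cstdio>

int main() {
  const int S = 4;  // pipeline depth (STAGES in the kernel)
  // Warm-up: start_pipes() issues copies for buffers 0 .. S-2 up front.
  for (int p = 0; p < S - 1; ++p) {
    std::printf("warm-up: issue copy into buffer %d\n", p);
  }
  // Steady state: one compute step and one new copy per iteration.
  for (int step = 0; step < 8; ++step) {
    int p = step % S;  // buffer consumed by this compute step
    std::printf("step %d: issue copy into buffer %d, compute from buffer %d\n",
                step, (p + S - 1) % S, p);
  }
  return 0;
}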
- #pragma unroll - for (int pipe = 0; pipe < stages;) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - matmul(pipe); - wait_for_stage(); - - fetch_to_registers(pipe + 1, (pipe + 1) % stages); - - pipe++; - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } else { - if (last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - } - } - thread_block_reduce(); - - if constexpr (group_blocks == -1) { - if constexpr (num_bits == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (group_blocks == -1 && num_bits == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], - &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], - &frag_c[i][0][0][2], &frag_c[i][1][0][2], - &frag_c[i][2][0][2], &frag_c[i][3][0][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], - &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], - &frag_c[i][0][0][3], &frag_c[i][1][0][3], - &frag_c[i][2][0][3], &frag_c[i][3][0][3], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], - &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], - &frag_c[i][0][1][2], &frag_c[i][1][1][2], - &frag_c[i][2][1][2], &frag_c[i][3][1][2], - frag_s[0][2]); - - scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], - &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], - &frag_c[i][0][1][3], &frag_c[i][1][1][3], - &frag_c[i][2][1][3], &frag_c[i][3][1][3], - frag_s[0][2]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - #pragma unroll - for (int i = 0; i < 
m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#endif - -#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS) { \ - cudaFuncSetAttribute( \ - Marlin_24, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin_24 \ - <<>>(A_ptr, B_ptr, meta_ptr, \ - C_ptr, s_ptr, prob_n, \ - prob_m, prob_k, locks); \ - } - -void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, - void* s, int prob_m, int prob_n, int prob_k, - void* workspace, int num_bits, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_m = -1, int sms = -1, int max_par = 16) { - int tot_n = prob_n; - int tot_n_blocks = ceildiv(tot_n, 16); - int pad = 16 * tot_n_blocks - tot_n; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - TORCH_CHECK(sms > 0); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - if (thread_k == -1 || thread_m == -1) { - if (prob_n <= 16) { - // For small batchizes, better partitioningif is slightly more important - // than better compute utilization - thread_k = 128; - thread_m = 128; - } else if (prob_n <= 256) { - thread_k = 64; - thread_m = 256; - } else { - thread_k = 32; - thread_m = 512; - } - } - - int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction - int thread_m_blocks = thread_m / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, - " is not divisible by thread_m = ", thread_m); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, - " is not divisible by group_blocks = ", group_blocks); - } - - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - const int4* meta_ptr = (const int4*)meta; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - constexpr int max_m_blocks = 4; - - int* locks = (int*)workspace; - for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { - int thread_n_blocks = tot_n_blocks - i; - prob_n = tot_n - 16 * i; - int par = 1; - if (thread_n_blocks > max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); - if (par > max_par) par = max_par; - prob_n = (max_m_blocks * 16) * par; - i += max_m_blocks * (par - 1); - thread_n_blocks = max_m_blocks; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. 
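The dispatch list that follows is easier to read once the launcher defaults above are turned into the template arguments. A minimal sketch of that arithmetic (the example inputs are just one of the default cases chosen above):

#include <cstdio>

int main() {
  // Defaults picked above for a batch (prob_n here) of at most 256 with groupsize 64.
  int thread_k = 64, thread_m = 256, groupsize = 64;
  int thread_k_blocks = thread_k / 32;  // the 2:4 kernel uses m16n8k32, so 32-wide k blocks
  int thread_m_blocks = thread_m / 16;
  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  // Prints 2, 16 and 4, i.e. the CALL_IF_2_4(num_bits, 16, *, 2, 4) rows below,
  // where * is the per-iteration batch tile count (1..4).
  std::printf("thread_k_blocks=%d thread_m_blocks=%d group_blocks=%d\n",
              thread_k_blocks, thread_m_blocks, group_blocks);
  return 0;
}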
- - // the false is start of the CALL_IF macros - if (false) { - } // BMxBNxBK, group - // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 16, 2, 2, 4) - CALL_IF_2_4(4, 16, 3, 2, -1) - CALL_IF_2_4(4, 16, 3, 2, 4) - CALL_IF_2_4(4, 16, 4, 2, -1) - CALL_IF_2_4(4, 16, 4, 2, 4) - - CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(4, 32, 2, 1, 4) - CALL_IF_2_4(4, 32, 3, 1, -1) - CALL_IF_2_4(4, 32, 3, 1, 4) - CALL_IF_2_4(4, 32, 4, 1, -1) - CALL_IF_2_4(4, 32, 4, 1, 4) - - // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 16, 2, 2, 4) - CALL_IF_2_4(8, 16, 3, 2, -1) - CALL_IF_2_4(8, 16, 3, 2, 4) - CALL_IF_2_4(8, 16, 4, 2, -1) - CALL_IF_2_4(8, 16, 4, 2, 4) - - CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 - CALL_IF_2_4(8, 32, 2, 1, 4) - CALL_IF_2_4(8, 32, 3, 1, -1) - CALL_IF_2_4(8, 32, 3, 1, 4) - CALL_IF_2_4(8, 32, 4, 1, -1) - CALL_IF_2_4(8, 32, 4, 1, 4) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; - } -} - -} // namespace marlin_24 - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. 
Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_24::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_24::tile_size)); - TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_24::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_24::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_24::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify meta - TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, - "b_meta.size(0) = ", b_meta.size(0), - " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); - TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), - " is not size_n * 2 = ", size_n * 2); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify b_meta device and strides - TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); - TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - int thread_k = -1; - int thread_m = -1; - int sms = -1; - int max_par = marlin_24::max_par; - - int groupsize = -1; - if (b_scales.size(0) > 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 - } - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 64, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_24::min_thread_n)); - int min_workspace_size = - (size_n / marlin_24::min_thread_n) * marlin_24::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_24::marlin_cuda_2_4( - a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_n, size_m, size_k, 
workspace.data_ptr(), - num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_m, sms, max_par); - - return c; -} diff --git a/server/marlin/setup.py b/server/marlin/setup.py deleted file mode 100644 index aed84e9eb..000000000 --- a/server/marlin/setup.py +++ /dev/null @@ -1,22 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -extra_compile_args = [] - -setup( - name="marlin_kernels", - ext_modules=[ - CUDAExtension( - name="marlin_kernels", - sources=[ - "marlin_kernels/gptq_marlin.cu", - "marlin_kernels/gptq_marlin_repack.cu", - "marlin_kernels/marlin_cuda_kernel.cu", - "marlin_kernels/sparse/marlin_24_cuda_kernel.cu", - "marlin_kernels/ext.cpp", - ], - extra_compile_args=extra_compile_args, - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/server/poetry.lock b/server/poetry.lock index 4984978a3..5072aa0bd 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -202,13 +202,13 @@ test = ["scipy"] [[package]] name = "certifi" -version = "2024.6.2" +version = "2024.7.4" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"}, - {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, + {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, + {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, ] [[package]] @@ -348,45 +348,47 @@ files = [ [[package]] name = "datasets" -version = "2.14.4" +version = "2.20.0" description = "HuggingFace community-driven open-source library of datasets" optional = true python-versions = ">=3.8.0" files = [ - {file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, - {file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, + {file = "datasets-2.20.0-py3-none-any.whl", hash = "sha256:76ac02e3bdfff824492e20678f0b6b1b6d080515957fe834b00c2ba8d6b18e5e"}, + {file = "datasets-2.20.0.tar.gz", hash = "sha256:3c4dbcd27e0f642b9d41d20ff2efa721a5e04b32b2ca4009e0fc9139e324553f"}, ] [package.dependencies] aiohttp = "*" -dill = ">=0.3.0,<0.3.8" -fsspec = {version = ">=2021.11.1", extras = ["http"]} -huggingface-hub = ">=0.14.0,<1.0.0" +dill = ">=0.3.0,<0.3.9" +filelock = "*" +fsspec = {version = ">=2023.1.0,<=2024.5.0", extras = ["http"]} +huggingface-hub = ">=0.21.2" multiprocess = "*" numpy = ">=1.17" packaging = "*" pandas = "*" -pyarrow = ">=8.0.0" +pyarrow = ">=15.0.0" +pyarrow-hotfix = "*" pyyaml = ">=5.1" -requests = ">=2.19.0" -tqdm = ">=4.62.1" +requests = ">=2.32.2" +tqdm = ">=4.66.3" xxhash = "*" [package.extras] -apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] +apache-beam = ["apache-beam (>=2.26.0)"] audio = ["librosa", "soundfile (>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff 
(>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] -docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] -jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +dev = ["Pillow (>=9.4.0)", "absl-py", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"] +jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] -quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] +quality = ["ruff (>=0.3.0)"] s3 = ["s3fs"] -tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] -tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +tensorflow = ["tensorflow (>=2.6.0)"] +tensorflow-gpu = ["tensorflow (>=2.6.0)"] +tests = ["Pillow (>=9.4.0)", "absl-py", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] torch = ["torch"] -vision = ["Pillow (>=6.2.1)"] +vision = ["Pillow (>=9.4.0)"] [[package]] name = "deprecated" @@ -407,17 +409,18 @@ dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] [[package]] name = "dill" -version = "0.3.7" +version = "0.3.8" description = "serialize all of Python" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, - {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = 
"sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, ] [package.extras] graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] [[package]] name = "diskcache" @@ -443,13 +446,13 @@ files = [ [[package]] name = "exceptiongroup" -version = "1.2.1" +version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" files = [ - {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, - {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, ] [package.extras] @@ -457,18 +460,18 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.4" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -559,13 +562,13 @@ files = [ [[package]] name = "fsspec" -version = "2024.6.0" +version = "2024.5.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.6.0-py3-none-any.whl", hash = "sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee"}, - {file = "fsspec-2024.6.0.tar.gz", hash = "sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2"}, + {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"}, + {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"}, ] [package.dependencies] @@ -577,7 +580,6 @@ adl = ["adlfs"] arrow = ["pyarrow (>=1)"] dask = ["dask", "distributed"] dev = ["pre-commit", "ruff"] -doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] fuse = ["fusepy"] @@ -601,17 +603,17 @@ tqdm = ["tqdm"] [[package]] name = "googleapis-common-protos" -version = "1.63.1" +version 
= "1.63.2" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" files = [ - {file = "googleapis-common-protos-1.63.1.tar.gz", hash = "sha256:c6442f7a0a6b2a80369457d79e6672bb7dcbaab88e0848302497e3ec80780a6a"}, - {file = "googleapis_common_protos-1.63.1-py2.py3-none-any.whl", hash = "sha256:0e1c2cdfcbc354b76e4a211a35ea35d6926a835cba1377073c4861db904a1877"}, + {file = "googleapis-common-protos-1.63.2.tar.gz", hash = "sha256:27c5abdffc4911f28101e635de1533fb4cfd2c37fbaa9174587c799fac90aa87"}, + {file = "googleapis_common_protos-1.63.2-py2.py3-none-any.whl", hash = "sha256:27a2499c7e8aff199665b22741997e485eccc8645aa9176c7c988e6fae507945"}, ] [package.dependencies] -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] @@ -635,61 +637,61 @@ testing = ["protobuf (>=4.21.9)"] [[package]] name = "grpcio" -version = "1.64.1" +version = "1.65.1" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.64.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:55697ecec192bc3f2f3cc13a295ab670f51de29884ca9ae6cd6247df55df2502"}, - {file = "grpcio-1.64.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:3b64ae304c175671efdaa7ec9ae2cc36996b681eb63ca39c464958396697daff"}, - {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:bac71b4b28bc9af61efcdc7630b166440bbfbaa80940c9a697271b5e1dabbc61"}, - {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c024ffc22d6dc59000faf8ad781696d81e8e38f4078cb0f2630b4a3cf231a90"}, - {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7cd5c1325f6808b8ae31657d281aadb2a51ac11ab081ae335f4f7fc44c1721d"}, - {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0a2813093ddb27418a4c99f9b1c223fab0b053157176a64cc9db0f4557b69bd9"}, - {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2981c7365a9353f9b5c864595c510c983251b1ab403e05b1ccc70a3d9541a73b"}, - {file = "grpcio-1.64.1-cp310-cp310-win32.whl", hash = "sha256:1262402af5a511c245c3ae918167eca57342c72320dffae5d9b51840c4b2f86d"}, - {file = "grpcio-1.64.1-cp310-cp310-win_amd64.whl", hash = "sha256:19264fc964576ddb065368cae953f8d0514ecc6cb3da8903766d9fb9d4554c33"}, - {file = "grpcio-1.64.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:58b1041e7c870bb30ee41d3090cbd6f0851f30ae4eb68228955d973d3efa2e61"}, - {file = "grpcio-1.64.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bbc5b1d78a7822b0a84c6f8917faa986c1a744e65d762ef6d8be9d75677af2ca"}, - {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5841dd1f284bd1b3d8a6eca3a7f062b06f1eec09b184397e1d1d43447e89a7ae"}, - {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8caee47e970b92b3dd948371230fcceb80d3f2277b3bf7fbd7c0564e7d39068e"}, - {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73819689c169417a4f978e562d24f2def2be75739c4bed1992435d007819da1b"}, - {file = "grpcio-1.64.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6503b64c8b2dfad299749cad1b595c650c91e5b2c8a1b775380fcf8d2cbba1e9"}, - {file = 
"grpcio-1.64.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1de403fc1305fd96cfa75e83be3dee8538f2413a6b1685b8452301c7ba33c294"}, - {file = "grpcio-1.64.1-cp311-cp311-win32.whl", hash = "sha256:d4d29cc612e1332237877dfa7fe687157973aab1d63bd0f84cf06692f04c0367"}, - {file = "grpcio-1.64.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e56462b05a6f860b72f0fa50dca06d5b26543a4e88d0396259a07dc30f4e5aa"}, - {file = "grpcio-1.64.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:4657d24c8063e6095f850b68f2d1ba3b39f2b287a38242dcabc166453e950c59"}, - {file = "grpcio-1.64.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:62b4e6eb7bf901719fce0ca83e3ed474ae5022bb3827b0a501e056458c51c0a1"}, - {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22"}, - {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:198908f9b22e2672a998870355e226a725aeab327ac4e6ff3a1399792ece4762"}, - {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b9d0acaa8d835a6566c640f48b50054f422d03e77e49716d4c4e8e279665a1"}, - {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5e42634a989c3aa6049f132266faf6b949ec2a6f7d302dbb5c15395b77d757eb"}, - {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b1a82e0b9b3022799c336e1fc0f6210adc019ae84efb7321d668129d28ee1efb"}, - {file = "grpcio-1.64.1-cp312-cp312-win32.whl", hash = "sha256:55260032b95c49bee69a423c2f5365baa9369d2f7d233e933564d8a47b893027"}, - {file = "grpcio-1.64.1-cp312-cp312-win_amd64.whl", hash = "sha256:c1a786ac592b47573a5bb7e35665c08064a5d77ab88a076eec11f8ae86b3e3f6"}, - {file = "grpcio-1.64.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:a011ac6c03cfe162ff2b727bcb530567826cec85eb8d4ad2bfb4bd023287a52d"}, - {file = "grpcio-1.64.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4d6dab6124225496010bd22690f2d9bd35c7cbb267b3f14e7a3eb05c911325d4"}, - {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:a5e771d0252e871ce194d0fdcafd13971f1aae0ddacc5f25615030d5df55c3a2"}, - {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c3c1b90ab93fed424e454e93c0ed0b9d552bdf1b0929712b094f5ecfe7a23ad"}, - {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20405cb8b13fd779135df23fabadc53b86522d0f1cba8cca0e87968587f50650"}, - {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0cc79c982ccb2feec8aad0e8fb0d168bcbca85bc77b080d0d3c5f2f15c24ea8f"}, - {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a3a035c37ce7565b8f4f35ff683a4db34d24e53dc487e47438e434eb3f701b2a"}, - {file = "grpcio-1.64.1-cp38-cp38-win32.whl", hash = "sha256:1257b76748612aca0f89beec7fa0615727fd6f2a1ad580a9638816a4b2eb18fd"}, - {file = "grpcio-1.64.1-cp38-cp38-win_amd64.whl", hash = "sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122"}, - {file = "grpcio-1.64.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:75dbbf415026d2862192fe1b28d71f209e2fd87079d98470db90bebe57b33179"}, - {file = "grpcio-1.64.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e3d9f8d1221baa0ced7ec7322a981e28deb23749c76eeeb3d33e18b72935ab62"}, - {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:5f8b75f64d5d324c565b263c67dbe4f0af595635bbdd93bb1a88189fc62ed2e5"}, - {file = 
"grpcio-1.64.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c84ad903d0d94311a2b7eea608da163dace97c5fe9412ea311e72c3684925602"}, - {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:940e3ec884520155f68a3b712d045e077d61c520a195d1a5932c531f11883489"}, - {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309"}, - {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac15b6c2c80a4d1338b04d42a02d376a53395ddf0ec9ab157cbaf44191f3ffdd"}, - {file = "grpcio-1.64.1-cp39-cp39-win32.whl", hash = "sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040"}, - {file = "grpcio-1.64.1-cp39-cp39-win_amd64.whl", hash = "sha256:ed6091fa0adcc7e4ff944090cf203a52da35c37a130efa564ded02b7aff63bcd"}, - {file = "grpcio-1.64.1.tar.gz", hash = "sha256:8d51dd1c59d5fa0f34266b80a3805ec29a1f26425c2a54736133f6d87fc4968a"}, + {file = "grpcio-1.65.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:3dc5f928815b8972fb83b78d8db5039559f39e004ec93ebac316403fe031a062"}, + {file = "grpcio-1.65.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:8333ca46053c35484c9f2f7e8d8ec98c1383a8675a449163cea31a2076d93de8"}, + {file = "grpcio-1.65.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:7af64838b6e615fff0ec711960ed9b6ee83086edfa8c32670eafb736f169d719"}, + {file = "grpcio-1.65.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbb64b4166362d9326f7efbf75b1c72106c1aa87f13a8c8b56a1224fac152f5c"}, + {file = "grpcio-1.65.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8422dc13ad93ec8caa2612b5032a2b9cd6421c13ed87f54db4a3a2c93afaf77"}, + {file = "grpcio-1.65.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:4effc0562b6c65d4add6a873ca132e46ba5e5a46f07c93502c37a9ae7f043857"}, + {file = "grpcio-1.65.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a6c71575a2fedf259724981fd73a18906513d2f306169c46262a5bae956e6364"}, + {file = "grpcio-1.65.1-cp310-cp310-win32.whl", hash = "sha256:34966cf526ef0ea616e008d40d989463e3db157abb213b2f20c6ce0ae7928875"}, + {file = "grpcio-1.65.1-cp310-cp310-win_amd64.whl", hash = "sha256:ca931de5dd6d9eb94ff19a2c9434b23923bce6f767179fef04dfa991f282eaad"}, + {file = "grpcio-1.65.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:bbb46330cc643ecf10bd9bd4ca8e7419a14b6b9dedd05f671c90fb2c813c6037"}, + {file = "grpcio-1.65.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d827a6fb9215b961eb73459ad7977edb9e748b23e3407d21c845d1d8ef6597e5"}, + {file = "grpcio-1.65.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:6e71aed8835f8d9fbcb84babc93a9da95955d1685021cceb7089f4f1e717d719"}, + {file = "grpcio-1.65.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a1c84560b3b2d34695c9ba53ab0264e2802721c530678a8f0a227951f453462"}, + {file = "grpcio-1.65.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27adee2338d697e71143ed147fe286c05810965d5d30ec14dd09c22479bfe48a"}, + {file = "grpcio-1.65.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f62652ddcadc75d0e7aa629e96bb61658f85a993e748333715b4ab667192e4e8"}, + {file = "grpcio-1.65.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:71a05fd814700dd9cb7d9a507f2f6a1ef85866733ccaf557eedacec32d65e4c2"}, + {file = "grpcio-1.65.1-cp311-cp311-win32.whl", hash = "sha256:b590f1ad056294dfaeac0b7e1b71d3d5ace638d8dd1f1147ce4bd13458783ba8"}, + {file = 
"grpcio-1.65.1-cp311-cp311-win_amd64.whl", hash = "sha256:12e9bdf3b5fd48e5fbe5b3da382ad8f97c08b47969f3cca81dd9b36b86ed39e2"}, + {file = "grpcio-1.65.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:54cb822e177374b318b233e54b6856c692c24cdbd5a3ba5335f18a47396bac8f"}, + {file = "grpcio-1.65.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:aaf3c54419a28d45bd1681372029f40e5bfb58e5265e3882eaf21e4a5f81a119"}, + {file = "grpcio-1.65.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:557de35bdfbe8bafea0a003dbd0f4da6d89223ac6c4c7549d78e20f92ead95d9"}, + {file = "grpcio-1.65.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8bfd95ef3b097f0cc86ade54eafefa1c8ed623aa01a26fbbdcd1a3650494dd11"}, + {file = "grpcio-1.65.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e6a8f3d6c41e6b642870afe6cafbaf7b61c57317f9ec66d0efdaf19db992b90"}, + {file = "grpcio-1.65.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1faaf7355ceed07ceaef0b9dcefa4c98daf1dd8840ed75c2de128c3f4a4d859d"}, + {file = "grpcio-1.65.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:60f1f38eed830488ad2a1b11579ef0f345ff16fffdad1d24d9fbc97ba31804ff"}, + {file = "grpcio-1.65.1-cp312-cp312-win32.whl", hash = "sha256:e75acfa52daf5ea0712e8aa82f0003bba964de7ae22c26d208cbd7bc08500177"}, + {file = "grpcio-1.65.1-cp312-cp312-win_amd64.whl", hash = "sha256:ff5a84907e51924973aa05ed8759210d8cdae7ffcf9e44fd17646cf4a902df59"}, + {file = "grpcio-1.65.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:1fbd6331f18c3acd7e09d17fd840c096f56eaf0ef830fbd50af45ae9dc8dfd83"}, + {file = "grpcio-1.65.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:de5b6be29116e094c5ef9d9e4252e7eb143e3d5f6bd6d50a78075553ab4930b0"}, + {file = "grpcio-1.65.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:e4a3cdba62b2d6aeae6027ae65f350de6dc082b72e6215eccf82628e79efe9ba"}, + {file = "grpcio-1.65.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:941c4869aa229d88706b78187d60d66aca77fe5c32518b79e3c3e03fc26109a2"}, + {file = "grpcio-1.65.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f40cebe5edb518d78b8131e87cb83b3ee688984de38a232024b9b44e74ee53d3"}, + {file = "grpcio-1.65.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2ca684ba331fb249d8a1ce88db5394e70dbcd96e58d8c4b7e0d7b141a453dce9"}, + {file = "grpcio-1.65.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8558f0083ddaf5de64a59c790bffd7568e353914c0c551eae2955f54ee4b857f"}, + {file = "grpcio-1.65.1-cp38-cp38-win32.whl", hash = "sha256:8d8143a3e3966f85dce6c5cc45387ec36552174ba5712c5dc6fcc0898fb324c0"}, + {file = "grpcio-1.65.1-cp38-cp38-win_amd64.whl", hash = "sha256:76e81a86424d6ca1ce7c16b15bdd6a964a42b40544bf796a48da241fdaf61153"}, + {file = "grpcio-1.65.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:cb5175f45c980ff418998723ea1b3869cce3766d2ab4e4916fbd3cedbc9d0ed3"}, + {file = "grpcio-1.65.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b12c1aa7b95abe73b3e04e052c8b362655b41c7798da69f1eaf8d186c7d204df"}, + {file = "grpcio-1.65.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:3019fb50128b21a5e018d89569ffaaaa361680e1346c2f261bb84a91082eb3d3"}, + {file = "grpcio-1.65.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ae15275ed98ea267f64ee9ddedf8ecd5306a5b5bb87972a48bfe24af24153e8"}, + {file = "grpcio-1.65.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f096ffb881f37e8d4f958b63c74bfc400c7cebd7a944b027357cd2fb8d91a57"}, + {file = 
"grpcio-1.65.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2f56b5a68fdcf17a0a1d524bf177218c3c69b3947cb239ea222c6f1867c3ab68"}, + {file = "grpcio-1.65.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:941596d419b9736ab548aa0feb5bbba922f98872668847bf0720b42d1d227b9e"}, + {file = "grpcio-1.65.1-cp39-cp39-win32.whl", hash = "sha256:5fd7337a823b890215f07d429f4f193d24b80d62a5485cf88ee06648591a0c57"}, + {file = "grpcio-1.65.1-cp39-cp39-win_amd64.whl", hash = "sha256:1bceeec568372cbebf554eae1b436b06c2ff24cfaf04afade729fb9035408c6c"}, + {file = "grpcio-1.65.1.tar.gz", hash = "sha256:3c492301988cd720cd145d84e17318d45af342e29ef93141228f9cd73222368b"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.64.1)"] +protobuf = ["grpcio-tools (>=1.65.1)"] [[package]] name = "grpcio-reflection" @@ -792,85 +794,77 @@ setuptools = "*" [[package]] name = "hf-transfer" -version = "0.1.6" -description = "" +version = "0.1.8" +description = "Speed up file transfers with the Hugging Face Hub." optional = false python-versions = ">=3.7" files = [ - {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, - {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, - {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, - {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, - {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, - {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, - {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, - {file = 
"hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, - {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, - {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, - {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, - {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, - {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, - {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, - {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, - {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, - {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, - {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, 
- {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, - {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, - {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, - {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, - {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, - {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, - {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, - {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, - {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, - {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, - {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, - {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, - {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, - {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, - {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, - {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, + {file = "hf_transfer-0.1.8-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:70858f9e94286738ed300484a45beb5cfee6a7ddac4c5886f9c6fce7823ac5ab"}, + {file = "hf_transfer-0.1.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:38adc73f0a8526319d90f7cc5dc2d5e4bb66f487a513d94b98aa6725be732e4a"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44d2f0c08198d8d899fe9d66e86aee2dd844bd7ce33888f261373fcec81d2a54"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1de2a4ef36f9e60b3d3bec00193c0aafd75771709f2ca51b9b162373f5af3d32"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e319269e3606a5ff2979296841766649ac73598a4a8eee2a968f86c8071fea5a"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f6026cf3be6a53ea42f92172f60c1c0675baaa9073f865e671b661dde5fd157"}, + {file = "hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f865c33ada5bd3650c2b46e59979f2d7755c3f517f8d0facc78576a0c7d26406"}, + {file = "hf_transfer-0.1.8-cp310-none-win32.whl", hash = "sha256:2054730e8d8ed21917c64be7199e06424b2bd08df1c43a72766afaed7992f2d3"}, + {file = "hf_transfer-0.1.8-cp310-none-win_amd64.whl", hash = "sha256:2b4f1a9446ba31170b5b1eca4e916504d18378a6b5fe959896bdac8a736a5ecb"}, + {file = "hf_transfer-0.1.8-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:e27c15fcc5869ad7e52bbc0bdec6106b288d1c463f8d2da92f28615a3b181361"}, + {file = "hf_transfer-0.1.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:871a0032d011ebc6409a73a8406b98b84ff2cd3ed7d9e1af8cdf4d660b9fab9b"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:686fa756e1e0214bb6327d33c66732c52274d94a8460beb50604ad988b391cf6"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:36a03b1b2911b0cf15b1b9d971a34b32dadcc4f2fd979aaff5979d6ce4017c34"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:079db90c81f41f4cf3227dfaaa855a9b8e9aef45bc7c2be29ce7232cd83ff881"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac08a4524127fdd14c234d4bcbe49d1c498acf5335c781714823179bcc8dc039"}, + {file = "hf_transfer-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837432e73cb17274a6782b6216e8ce058aa325a475dc44a5a6a753d48b86d18a"}, + {file = "hf_transfer-0.1.8-cp311-none-win32.whl", hash = "sha256:b180f9823dde35aba9bc0f1d0c04ac8a873baebd3732a7ffe4f11940abc7df0d"}, + {file = "hf_transfer-0.1.8-cp311-none-win_amd64.whl", hash = 
"sha256:37907d2135cebcf8b6d419bb575148d89c224f16b69357f027bd29d0e85c6529"}, + {file = "hf_transfer-0.1.8-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baf948f4f493949309cbe60529620b9b0aef854a22b6e526753364acc57c09b6"}, + {file = "hf_transfer-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bce5c8bdefa478c5d5eaa646cc4ce1df5cfe764d98572ad0c6b8773e98d49f6"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54d6f8a1a86128d651a3799e1267c343d60f81f2c565d7c5416eb8e674e4cf0e"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f79fd1b0c2ed93efb4c5f684118d7a762ecdd218e170df8208c4e13d3dcd4959"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:414df35692670683bf5623498ef9d88a8df5d77e9516515da6e2b34d1054c11f"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c9798d5f951f66b96d40a7a53910260cb5874fda56cf5944dddb7c571f37ec3"}, + {file = "hf_transfer-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:060c661691f85a61392e57579c80eb64b5ee277434e81fb582f605c1c8ff05d5"}, + {file = "hf_transfer-0.1.8-cp312-none-win32.whl", hash = "sha256:f7840e32379820c3e1571a480238e05ea043e970c99d2e999578004a2eb17788"}, + {file = "hf_transfer-0.1.8-cp312-none-win_amd64.whl", hash = "sha256:9a3204ec423cc5e659872e8179f8704ad9ce2abb1e6a991f8838aedf1dc07830"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09949e86ad63ee139e463fd0dfaf401515ae70445854199f61d545514c65f744"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf1a74552845b93ea972e6e7131ef54e56056aa54137e93a40faf3fbcb2442ff"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959bcb3afb4ee6f2a07031a947dba98ec0b64c001bc914fbd8fc32e13a287162"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e01eecdb8162bd61dab9090fbd9f8034dd8b5755ef727a21ca8a057f80cb91ee"}, + {file = "hf_transfer-0.1.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50650a38e9d31f5ad8f010e4598bf304ecd99c17162e7d93f67e031571b864ee"}, + {file = "hf_transfer-0.1.8-cp37-none-win32.whl", hash = "sha256:e29b9d1d378138f2f4eae0e93ca94af3b5d45f4532eef69f1ab97fe06f9c9d9e"}, + {file = "hf_transfer-0.1.8-cp37-none-win_amd64.whl", hash = "sha256:cfd6cef43ae883103117a371f8ebae4e7f9637bc6fb480f1be5568e2fe22a8a7"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92a68f7a0043cca8a0de4decc760dca177530944cbab502afac503bd1b2fa01a"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e3138e408179f80a5480598e32f8e1abb564915cbde4d3bc8da52811c75dc3ea"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4544d148930ad34442d43b8fa911c8479c04a95b858b1d1f91e0b7da77082fad"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a851794b9f029965664f8c3002c957fccf21685e9397ceb4f9f19c986dee8ad3"}, + {file = "hf_transfer-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:791aaf87c5319ac83edb6ab2994b3db19924c49d6ff667dd3d8a610b455ff70a"}, + {file = "hf_transfer-0.1.8-cp38-none-win32.whl", 
hash = "sha256:8f71e5d35d3a3160dcca12fdcc8119033aeacaa6a32838a7ad9f9cb1008bbe58"}, + {file = "hf_transfer-0.1.8-cp38-none-win_amd64.whl", hash = "sha256:543287b4ceb1e25501580b99690f7f0df9d3631d29306f37cbd97e918c732944"}, + {file = "hf_transfer-0.1.8-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:7ce02a18bd0bb2343e707ac85b68c946bc37623ee24150c69158f6b2b2c7a98f"}, + {file = "hf_transfer-0.1.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64d7f8dbd64ba183ed1df75d47c84e075ff666ceaa335bff1de16b09eaac5b80"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e7858694e11419ae27e542fb8fc0d0e54d46ff7768fe73bc359d70b8f5aa578"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed116cd9d1edfa32c0136d7cb8e5f1afd2b32df43c49085d428f108fc8e1c8f"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e385d0da9c6b3472ab29285d2d46c9f9903205b8d108f88a82f3f85aafae0ab"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98f75fa4b86ef15433cd907807ac77d1fb39d7e7b790bfd39c7ae9c385bf0200"}, + {file = "hf_transfer-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1a63ad947d2901425ac0a3ed70c3696dfde27fadb0482ed763bdd5cc946b278"}, + {file = "hf_transfer-0.1.8-cp39-none-win32.whl", hash = "sha256:3e74096915813ae842ea6a5bdf10c0fef960aa51a35a560955b3e61cdfe3db57"}, + {file = "hf_transfer-0.1.8-cp39-none-win_amd64.whl", hash = "sha256:05ea16307bf4a5eb097cbc6e5057e4eb5e080a138af23ef639fd38857723c288"}, + {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:928ff036c3e98e10dcfbdb4fcdfc4592d37a5cc8e365a7ba8dfd4337e849d675"}, + {file = "hf_transfer-0.1.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d49ba3ce67035f460ae1924fe2feafec155cb535eec7f31ed5109c19064cd294"}, + {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b01f5872c62cfee3ec9ca5c738818296f69f8adf84b4d8d15f2a5601d9dda339"}, + {file = "hf_transfer-0.1.8-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:659d4212d50847a5165666bf43d67727679b4f694ef9c413613cc27093136527"}, + {file = "hf_transfer-0.1.8.tar.gz", hash = "sha256:26d229468152e7a3ec12664cac86b8c2800695fd85f9c9a96677a775cc04f0b3"}, ] [[package]] name = "huggingface-hub" -version = "0.23.2" +version = "0.23.5" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"}, - {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"}, + {file = "huggingface_hub-0.23.5-py3-none-any.whl", hash = "sha256:d7a7d337615e11a45cc14a0ce5a605db6b038dc24af42866f731684825226e90"}, + {file = "huggingface_hub-0.23.5.tar.gz", hash = "sha256:67a9caba79b71235be3752852ca27da86bd54311d2424ca8afdb8dda056edf98"}, ] [package.dependencies] @@ -937,20 +931,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "intel-openmp" -version = "2021.4.0" -description = "Intel OpenMP* Runtime Library" -optional = true 
-python-versions = "*" -files = [ - {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, -] - [[package]] name = "interegular" version = "0.3.3" @@ -992,13 +972,13 @@ files = [ [[package]] name = "jsonschema" -version = "4.22.0" +version = "4.23.0" description = "An implementation of JSON Schema validation for Python" optional = true python-versions = ">=3.8" files = [ - {file = "jsonschema-4.22.0-py3-none-any.whl", hash = "sha256:ff4cfd6b1367a40e7bc6411caec72effadd3db0bbe5017de188f2d6108335802"}, - {file = "jsonschema-4.22.0.tar.gz", hash = "sha256:5b22d434a45935119af990552c862e5d6d564e8f6601206b305a61fdf661a2b7"}, + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, ] [package.dependencies] @@ -1009,7 +989,7 @@ rpds-py = ">=0.7.1" [package.extras] format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] -format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] [[package]] name = "jsonschema-specifications" @@ -1044,32 +1024,32 @@ regex = ["regex"] [[package]] name = "llvmlite" -version = "0.42.0" +version = "0.43.0" description = "lightweight wrapper around basic LLVM functionality" optional = true python-versions = ">=3.9" files = [ - {file = "llvmlite-0.42.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3366938e1bf63d26c34fbfb4c8e8d2ded57d11e0567d5bb243d89aab1eb56098"}, - {file = "llvmlite-0.42.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c35da49666a21185d21b551fc3caf46a935d54d66969d32d72af109b5e7d2b6f"}, - {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70f44ccc3c6220bd23e0ba698a63ec2a7d3205da0d848804807f37fc243e3f77"}, - {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f8d8717a9073b9e0246998de89929071d15b47f254c10eef2310b9aac033d"}, - {file = "llvmlite-0.42.0-cp310-cp310-win_amd64.whl", hash = "sha256:8d90edf400b4ceb3a0e776b6c6e4656d05c7187c439587e06f86afceb66d2be5"}, - {file = "llvmlite-0.42.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ae511caed28beaf1252dbaf5f40e663f533b79ceb408c874c01754cafabb9cbf"}, - {file = "llvmlite-0.42.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81e674c2fe85576e6c4474e8c7e7aba7901ac0196e864fe7985492b737dbab65"}, - {file = 
"llvmlite-0.42.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb3975787f13eb97629052edb5017f6c170eebc1c14a0433e8089e5db43bcce6"}, - {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5bece0cdf77f22379f19b1959ccd7aee518afa4afbd3656c6365865f84903f9"}, - {file = "llvmlite-0.42.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e0c4c11c8c2aa9b0701f91b799cb9134a6a6de51444eff5a9087fc7c1384275"}, - {file = "llvmlite-0.42.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:08fa9ab02b0d0179c688a4216b8939138266519aaa0aa94f1195a8542faedb56"}, - {file = "llvmlite-0.42.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b2fce7d355068494d1e42202c7aff25d50c462584233013eb4470c33b995e3ee"}, - {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebe66a86dc44634b59a3bc860c7b20d26d9aaffcd30364ebe8ba79161a9121f4"}, - {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d47494552559e00d81bfb836cf1c4d5a5062e54102cc5767d5aa1e77ccd2505c"}, - {file = "llvmlite-0.42.0-cp312-cp312-win_amd64.whl", hash = "sha256:05cb7e9b6ce69165ce4d1b994fbdedca0c62492e537b0cc86141b6e2c78d5888"}, - {file = "llvmlite-0.42.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bdd3888544538a94d7ec99e7c62a0cdd8833609c85f0c23fcb6c5c591aec60ad"}, - {file = "llvmlite-0.42.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0936c2067a67fb8816c908d5457d63eba3e2b17e515c5fe00e5ee2bace06040"}, - {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a78ab89f1924fc11482209f6799a7a3fc74ddc80425a7a3e0e8174af0e9e2301"}, - {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7599b65c7af7abbc978dbf345712c60fd596aa5670496561cc10e8a71cebfb2"}, - {file = "llvmlite-0.42.0-cp39-cp39-win_amd64.whl", hash = "sha256:43d65cc4e206c2e902c1004dd5418417c4efa6c1d04df05c6c5675a27e8ca90e"}, - {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, + {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, + {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"}, + {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"}, + {file = 
"llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"}, + {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"}, + {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"}, + {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"}, ] [[package]] @@ -1160,22 +1140,72 @@ files = [ ] [[package]] -name = "mkl" -version = "2021.4.0" -description = "Intel® oneAPI Math Kernel Library" +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" optional = true -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, - {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, - {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"}, ] [package.dependencies] -intel-openmp = "==2021.*" -tbb = "==2021.*" +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = 
"sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" + +[[package]] +name = "marlin-kernels" +version = "0.2.0" +description = "Marlin quantization kernels" +optional = true +python-versions = ">=3.7" +files = [ + {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"}, +] + +[package.dependencies] +torch = "*" + +[package.source] +type = "url" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mpmath" @@ -1295,31 +1325,27 @@ files = [ [[package]] name = "multiprocess" -version = "0.70.15" +version = "0.70.16" description = "better multiprocessing and multithreading in Python" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"}, - {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"}, - {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"}, - {file = "multiprocess-0.70.15-py37-none-any.whl", 
hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"}, - {file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"}, - {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"}, - {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, ] [package.dependencies] -dill = ">=0.3.7" +dill = ">=0.3.8" [[package]] name = "nest-asyncio" @@ -1352,37 +1378,37 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "numba" -version = "0.59.1" +version = "0.60.0" description = "compiling Python code using LLVM" optional = true python-versions = ">=3.9" files = [ - {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, - {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, - {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, - {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, - {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, - {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, - {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, - {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, - {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, - {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, - {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, - {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, - {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, - {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, - {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, - {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, - {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, - {file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, - {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, - {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, - {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, + {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, + {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, + {file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"}, + {file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"}, + {file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"}, + {file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"}, + {file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"}, + {file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"}, + {file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"}, + {file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"}, + {file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"}, + {file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"}, + {file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"}, + {file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"}, + {file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"}, + {file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"}, + {file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"}, + {file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"}, + {file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"}, + {file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"}, + {file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"}, ] [package.dependencies] -llvmlite = "==0.42.*" -numpy = ">=1.22,<1.27" +llvmlite = "==0.43.*" +numpy = ">=1.22,<2.1" [[package]] name = "numpy" @@ -1475,12 +1501,13 @@ files = [ [[package]] name = "nvidia-cudnn-cu12" -version = "8.9.2.26" +version = "9.1.0.70" description = "cuDNN runtime libraries" optional = true python-versions = ">=3" files = [ - {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, + {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, + {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, ] [package.dependencies] @@ -1551,13 +1578,14 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.5.40" +version = "12.5.82" description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, - {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"}, + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"}, + {file = 
"nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"}, ] [[package]] @@ -1770,13 +1798,13 @@ test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets" [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -1883,84 +1911,95 @@ test = ["black", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameter [[package]] name = "pillow" -version = "10.3.0" +version = "10.4.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.8" files = [ - {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"}, - {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"}, - {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"}, - {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"}, - {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"}, - {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"}, - {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"}, - {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"}, - {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"}, - {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"}, - {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"}, - {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"}, - {file = 
"pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"}, - {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"}, - {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"}, - {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"}, - {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"}, - {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"}, - {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"}, - {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"}, - {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"}, - {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"}, - {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"}, - {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"}, - {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"}, - {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"}, - {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"}, - {file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"}, - {file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"}, - {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"}, - {file = 
"pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"}, - {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"}, - {file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"}, - {file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"}, - {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"}, - {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"}, - {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"}, - {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"}, - {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"}, - {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"}, - {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"}, - {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"}, - {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = 
"sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"}, - {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"}, - {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"}, + {file = "pillow-10.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e"}, + {file = "pillow-10.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b"}, + {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc"}, + {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e"}, + {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46"}, + {file = "pillow-10.4.0-cp310-cp310-win32.whl", hash = "sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984"}, + {file = "pillow-10.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141"}, + {file = "pillow-10.4.0-cp310-cp310-win_arm64.whl", hash = "sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1"}, + {file = "pillow-10.4.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c"}, + {file = "pillow-10.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe"}, + {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319"}, + {file = 
"pillow-10.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d"}, + {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696"}, + {file = "pillow-10.4.0-cp311-cp311-win32.whl", hash = "sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496"}, + {file = "pillow-10.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91"}, + {file = "pillow-10.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22"}, + {file = "pillow-10.4.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94"}, + {file = "pillow-10.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a"}, + {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b"}, + {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9"}, + {file = "pillow-10.4.0-cp312-cp312-win32.whl", hash = "sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42"}, + {file = "pillow-10.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a"}, + {file = "pillow-10.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9"}, + {file = "pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3"}, + {file = "pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc"}, + {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a"}, + {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309"}, + {file = 
"pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060"}, + {file = "pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea"}, + {file = "pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d"}, + {file = "pillow-10.4.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736"}, + {file = "pillow-10.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b"}, + {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd"}, + {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84"}, + {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0"}, + {file = "pillow-10.4.0-cp38-cp38-win32.whl", hash = "sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e"}, + {file = "pillow-10.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab"}, + {file = "pillow-10.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d"}, + {file = "pillow-10.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b"}, + {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c"}, + {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1"}, + {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df"}, + {file = "pillow-10.4.0-cp39-cp39-win32.whl", hash = "sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef"}, + {file = "pillow-10.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5"}, + {file = "pillow-10.4.0-cp39-cp39-win_arm64.whl", hash = "sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = 
"sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885"}, + {file = "pillow-10.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27"}, + {file = "pillow-10.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3"}, + {file = "pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06"}, ] [package.extras] -docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +docs = ["furo", "olefile", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] @@ -2018,27 +2057,28 @@ files = [ [[package]] name = "psutil" -version = "5.9.8" +version = "6.0.0" description = "Cross-platform lib for process and system monitoring in Python." 
optional = true -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"}, - {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"}, - {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"}, - {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, - {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"}, - {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"}, - {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, - {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, - {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, - {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = 
"sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, ] [package.extras] @@ -2057,157 +2097,181 @@ files = [ [[package]] name = "pyarrow" -version = "16.1.0" +version = "17.0.0" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.8" files = [ - {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, - {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, - {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, - {file = 
"pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, - {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, - {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, - {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, - {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, - {file = "pyarrow-16.1.0.tar.gz", hash = 
"sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] numpy = ">=1.16.6" +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + +[[package]] +name = "pyarrow-hotfix" +version = "0.6" +description = "" +optional = true +python-versions = ">=3.5" +files = [ + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, +] + [[package]] name = "pydantic" -version = "2.7.3" +version = "2.8.2" description = "Data validation using Python type hints" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"}, - {file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"}, + {file = "pydantic-2.8.2-py3-none-any.whl", hash = "sha256:73ee9fddd406dc318b885c7a2eab8a6472b68b8fb5ba8150949fc3db939f23c8"}, + {file = "pydantic-2.8.2.tar.gz", hash = "sha256:6f62c13d067b0755ad1c21a34bdd06c0c12625a22b0fc09c6b149816604f7c2a"}, ] [package.dependencies] annotated-types = ">=0.4.0" -pydantic-core = "2.18.4" -typing-extensions = ">=4.6.1" +pydantic-core = "2.20.1" +typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} [package.extras] email = ["email-validator (>=2.0.0)"] [[package]] name = "pydantic-core" -version = "2.18.4" +version = "2.20.1" 
description = "Core functionality for Pydantic validation and serialization" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic_core-2.18.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:f76d0ad001edd426b92233d45c746fd08f467d56100fd8f30e9ace4b005266e4"}, - {file = "pydantic_core-2.18.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:59ff3e89f4eaf14050c8022011862df275b552caef8082e37b542b066ce1ff26"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a55b5b16c839df1070bc113c1f7f94a0af4433fcfa1b41799ce7606e5c79ce0a"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4d0dcc59664fcb8974b356fe0a18a672d6d7cf9f54746c05f43275fc48636851"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8951eee36c57cd128f779e641e21eb40bc5073eb28b2d23f33eb0ef14ffb3f5d"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4701b19f7e3a06ea655513f7938de6f108123bf7c86bbebb1196eb9bd35cf724"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e00a3f196329e08e43d99b79b286d60ce46bed10f2280d25a1718399457e06be"}, - {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:97736815b9cc893b2b7f663628e63f436018b75f44854c8027040e05230eeddb"}, - {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6891a2ae0e8692679c07728819b6e2b822fb30ca7445f67bbf6509b25a96332c"}, - {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bc4ff9805858bd54d1a20efff925ccd89c9d2e7cf4986144b30802bf78091c3e"}, - {file = "pydantic_core-2.18.4-cp310-none-win32.whl", hash = "sha256:1b4de2e51bbcb61fdebd0ab86ef28062704f62c82bbf4addc4e37fa4b00b7cbc"}, - {file = "pydantic_core-2.18.4-cp310-none-win_amd64.whl", hash = "sha256:6a750aec7bf431517a9fd78cb93c97b9b0c496090fee84a47a0d23668976b4b0"}, - {file = "pydantic_core-2.18.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:942ba11e7dfb66dc70f9ae66b33452f51ac7bb90676da39a7345e99ffb55402d"}, - {file = "pydantic_core-2.18.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b2ebef0e0b4454320274f5e83a41844c63438fdc874ea40a8b5b4ecb7693f1c4"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a642295cd0c8df1b86fc3dced1d067874c353a188dc8e0f744626d49e9aa51c4"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f09baa656c904807e832cf9cce799c6460c450c4ad80803517032da0cd062e2"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98906207f29bc2c459ff64fa007afd10a8c8ac080f7e4d5beff4c97086a3dabd"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19894b95aacfa98e7cb093cd7881a0c76f55731efad31073db4521e2b6ff5b7d"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fbbdc827fe5e42e4d196c746b890b3d72876bdbf160b0eafe9f0334525119c8"}, - {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f85d05aa0918283cf29a30b547b4df2fbb56b45b135f9e35b6807cb28bc47951"}, - {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:e85637bc8fe81ddb73fda9e56bab24560bdddfa98aa64f87aaa4e4b6730c23d2"}, - {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2f5966897e5461f818e136b8451d0551a2e77259eb0f73a837027b47dc95dab9"}, - {file = "pydantic_core-2.18.4-cp311-none-win32.whl", hash = "sha256:44c7486a4228413c317952e9d89598bcdfb06399735e49e0f8df643e1ccd0558"}, - {file = "pydantic_core-2.18.4-cp311-none-win_amd64.whl", hash = "sha256:8a7164fe2005d03c64fd3b85649891cd4953a8de53107940bf272500ba8a788b"}, - {file = "pydantic_core-2.18.4-cp311-none-win_arm64.whl", hash = "sha256:4e99bc050fe65c450344421017f98298a97cefc18c53bb2f7b3531eb39bc7805"}, - {file = "pydantic_core-2.18.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6f5c4d41b2771c730ea1c34e458e781b18cc668d194958e0112455fff4e402b2"}, - {file = "pydantic_core-2.18.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2fdf2156aa3d017fddf8aea5adfba9f777db1d6022d392b682d2a8329e087cef"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:847a35c4d58721c5dc3dba599878ebbdfd96784f3fb8bb2c356e123bdcd73f34"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c40d4eaad41f78e3bbda31b89edc46a3f3dc6e171bf0ecf097ff7a0ffff7cb1"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21a5e440dbe315ab9825fcd459b8814bb92b27c974cbc23c3e8baa2b76890077"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01dd777215e2aa86dfd664daed5957704b769e726626393438f9c87690ce78c3"}, - {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4b06beb3b3f1479d32befd1f3079cc47b34fa2da62457cdf6c963393340b56e9"}, - {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:564d7922e4b13a16b98772441879fcdcbe82ff50daa622d681dd682175ea918c"}, - {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0eb2a4f660fcd8e2b1c90ad566db2b98d7f3f4717c64fe0a83e0adb39766d5b8"}, - {file = "pydantic_core-2.18.4-cp312-none-win32.whl", hash = "sha256:8b8bab4c97248095ae0c4455b5a1cd1cdd96e4e4769306ab19dda135ea4cdb07"}, - {file = "pydantic_core-2.18.4-cp312-none-win_amd64.whl", hash = "sha256:14601cdb733d741b8958224030e2bfe21a4a881fb3dd6fbb21f071cabd48fa0a"}, - {file = "pydantic_core-2.18.4-cp312-none-win_arm64.whl", hash = "sha256:c1322d7dd74713dcc157a2b7898a564ab091ca6c58302d5c7b4c07296e3fd00f"}, - {file = "pydantic_core-2.18.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:823be1deb01793da05ecb0484d6c9e20baebb39bd42b5d72636ae9cf8350dbd2"}, - {file = "pydantic_core-2.18.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ebef0dd9bf9b812bf75bda96743f2a6c5734a02092ae7f721c048d156d5fabae"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae1d6df168efb88d7d522664693607b80b4080be6750c913eefb77e34c12c71a"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f9899c94762343f2cc2fc64c13e7cae4c3cc65cdfc87dd810a31654c9b7358cc"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:99457f184ad90235cfe8461c4d70ab7dd2680e28821c29eca00252ba90308c78"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18f469a3d2a2fdafe99296a87e8a4c37748b5080a26b806a707f25a902c040a8"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7cdf28938ac6b8b49ae5e92f2735056a7ba99c9b110a474473fd71185c1af5d"}, - {file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:938cb21650855054dc54dfd9120a851c974f95450f00683399006aa6e8abb057"}, - {file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:44cd83ab6a51da80fb5adbd9560e26018e2ac7826f9626bc06ca3dc074cd198b"}, - {file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:972658f4a72d02b8abfa2581d92d59f59897d2e9f7e708fdabe922f9087773af"}, - {file = "pydantic_core-2.18.4-cp38-none-win32.whl", hash = "sha256:1d886dc848e60cb7666f771e406acae54ab279b9f1e4143babc9c2258213daa2"}, - {file = "pydantic_core-2.18.4-cp38-none-win_amd64.whl", hash = "sha256:bb4462bd43c2460774914b8525f79b00f8f407c945d50881568f294c1d9b4443"}, - {file = "pydantic_core-2.18.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:44a688331d4a4e2129140a8118479443bd6f1905231138971372fcde37e43528"}, - {file = "pydantic_core-2.18.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a2fdd81edd64342c85ac7cf2753ccae0b79bf2dfa063785503cb85a7d3593223"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86110d7e1907ab36691f80b33eb2da87d780f4739ae773e5fc83fb272f88825f"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46387e38bd641b3ee5ce247563b60c5ca098da9c56c75c157a05eaa0933ed154"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:123c3cec203e3f5ac7b000bd82235f1a3eced8665b63d18be751f115588fea30"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dc1803ac5c32ec324c5261c7209e8f8ce88e83254c4e1aebdc8b0a39f9ddb443"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53db086f9f6ab2b4061958d9c276d1dbe3690e8dd727d6abf2321d6cce37fa94"}, - {file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:abc267fa9837245cc28ea6929f19fa335f3dc330a35d2e45509b6566dc18be23"}, - {file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a0d829524aaefdebccb869eed855e2d04c21d2d7479b6cada7ace5448416597b"}, - {file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:509daade3b8649f80d4e5ff21aa5673e4ebe58590b25fe42fac5f0f52c6f034a"}, - {file = "pydantic_core-2.18.4-cp39-none-win32.whl", hash = "sha256:ca26a1e73c48cfc54c4a76ff78df3727b9d9f4ccc8dbee4ae3f73306a591676d"}, - {file = "pydantic_core-2.18.4-cp39-none-win_amd64.whl", hash = "sha256:c67598100338d5d985db1b3d21f3619ef392e185e71b8d52bceacc4a7771ea7e"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:574d92eac874f7f4db0ca653514d823a0d22e2354359d0759e3f6a406db5d55d"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1f4d26ceb5eb9eed4af91bebeae4b06c3fb28966ca3a8fb765208cf6b51102ab"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:77450e6d20016ec41f43ca4a6c63e9fdde03f0ae3fe90e7c27bdbeaece8b1ed4"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d323a01da91851a4f17bf592faf46149c9169d68430b3146dcba2bb5e5719abc"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43d447dd2ae072a0065389092a231283f62d960030ecd27565672bd40746c507"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:578e24f761f3b425834f297b9935e1ce2e30f51400964ce4801002435a1b41ef"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:81b5efb2f126454586d0f40c4d834010979cb80785173d1586df845a632e4e6d"}, - {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ab86ce7c8f9bea87b9d12c7f0af71102acbf5ecbc66c17796cff45dae54ef9a5"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:90afc12421df2b1b4dcc975f814e21bc1754640d502a2fbcc6d41e77af5ec312"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:51991a89639a912c17bef4b45c87bd83593aee0437d8102556af4885811d59f5"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:293afe532740370aba8c060882f7d26cfd00c94cae32fd2e212a3a6e3b7bc15e"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48ece5bde2e768197a2d0f6e925f9d7e3e826f0ad2271120f8144a9db18d5c8"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eae237477a873ab46e8dd748e515c72c0c804fb380fbe6c85533c7de51f23a8f"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:834b5230b5dfc0c1ec37b2fda433b271cbbc0e507560b5d1588e2cc1148cf1ce"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e858ac0a25074ba4bce653f9b5d0a85b7456eaddadc0ce82d3878c22489fa4ee"}, - {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2fd41f6eff4c20778d717af1cc50eca52f5afe7805ee530a4fbd0bae284f16e9"}, - {file = "pydantic_core-2.18.4.tar.gz", hash = "sha256:ec3beeada09ff865c344ff3bc2f427f5e6c26401cc6113d77e372c3fdac73864"}, + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3acae97ffd19bf091c72df4d726d552c473f3576409b2a7ca36b2f535ffff4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41f4c96227a67a013e7de5ff8f20fb496ce573893b7f4f2707d065907bffdbd6"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f239eb799a2081495ea659d8d4a43a8f42cd1fe9ff2e7e436295c38a10c286a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53e431da3fc53360db73eedf6f7124d1076e1b4ee4276b36fb25514544ceb4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1f62b2413c3a0e846c3b838b2ecd6c7a19ec6793b2a522745b0869e37ab5bc1"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d41e6daee2813ecceea8eda38062d69e280b39df793f5a942fa515b8ed67953"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d482efec8b7dc6bfaedc0f166b2ce349df0011f5d2f1f25537ced4cfc34fd98"}, + {file = 
"pydantic_core-2.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e93e1a4b4b33daed65d781a57a522ff153dcf748dee70b40c7258c5861e1768a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7c4ea22b6739b162c9ecaaa41d718dfad48a244909fe7ef4b54c0b530effc5a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4f2790949cf385d985a31984907fecb3896999329103df4e4983a4a41e13e840"}, + {file = "pydantic_core-2.20.1-cp310-none-win32.whl", hash = "sha256:5e999ba8dd90e93d57410c5e67ebb67ffcaadcea0ad973240fdfd3a135506250"}, + {file = "pydantic_core-2.20.1-cp310-none-win_amd64.whl", hash = "sha256:512ecfbefef6dac7bc5eaaf46177b2de58cdf7acac8793fe033b24ece0b9566c"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d2a8fa9d6d6f891f3deec72f5cc668e6f66b188ab14bb1ab52422fe8e644f312"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:175873691124f3d0da55aeea1d90660a6ea7a3cfea137c38afa0a5ffabe37b88"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37eee5b638f0e0dcd18d21f59b679686bbd18917b87db0193ae36f9c23c355fc"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25e9185e2d06c16ee438ed39bf62935ec436474a6ac4f9358524220f1b236e43"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:150906b40ff188a3260cbee25380e7494ee85048584998c1e66df0c7a11c17a6"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ad4aeb3e9a97286573c03df758fc7627aecdd02f1da04516a86dc159bf70121"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3f3ed29cd9f978c604708511a1f9c2fdcb6c38b9aae36a51905b8811ee5cbf1"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0dae11d8f5ded51699c74d9548dcc5938e0804cc8298ec0aa0da95c21fff57b"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faa6b09ee09433b87992fb5a2859efd1c264ddc37280d2dd5db502126d0e7f27"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9dc1b507c12eb0481d071f3c1808f0529ad41dc415d0ca11f7ebfc666e66a18b"}, + {file = "pydantic_core-2.20.1-cp311-none-win32.whl", hash = "sha256:fa2fddcb7107e0d1808086ca306dcade7df60a13a6c347a7acf1ec139aa6789a"}, + {file = "pydantic_core-2.20.1-cp311-none-win_amd64.whl", hash = "sha256:40a783fb7ee353c50bd3853e626f15677ea527ae556429453685ae32280c19c2"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:595ba5be69b35777474fa07f80fc260ea71255656191adb22a8c53aba4479231"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a4f55095ad087474999ee28d3398bae183a66be4823f753cd7d67dd0153427c9"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9aa05d09ecf4c75157197f27cdc9cfaeb7c5f15021c6373932bf3e124af029f"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e97fdf088d4b31ff4ba35db26d9cc472ac7ef4a2ff2badeabf8d727b3377fc52"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc633a9fe1eb87e250b5c57d389cf28998e4292336926b0b6cdaee353f89a237"}, + {file = 
"pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d573faf8eb7e6b1cbbcb4f5b247c60ca8be39fe2c674495df0eb4318303137fe"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26dc97754b57d2fd00ac2b24dfa341abffc380b823211994c4efac7f13b9e90e"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33499e85e739a4b60c9dac710c20a08dc73cb3240c9a0e22325e671b27b70d24"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bebb4d6715c814597f85297c332297c6ce81e29436125ca59d1159b07f423eb1"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:516d9227919612425c8ef1c9b869bbbee249bc91912c8aaffb66116c0b447ebd"}, + {file = "pydantic_core-2.20.1-cp312-none-win32.whl", hash = "sha256:469f29f9093c9d834432034d33f5fe45699e664f12a13bf38c04967ce233d688"}, + {file = "pydantic_core-2.20.1-cp312-none-win_amd64.whl", hash = "sha256:035ede2e16da7281041f0e626459bcae33ed998cca6a0a007a5ebb73414ac72d"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0827505a5c87e8aa285dc31e9ec7f4a17c81a813d45f70b1d9164e03a813a686"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:19c0fa39fa154e7e0b7f82f88ef85faa2a4c23cc65aae2f5aea625e3c13c735a"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa223cd1e36b642092c326d694d8bf59b71ddddc94cdb752bbbb1c5c91d833b"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c336a6d235522a62fef872c6295a42ecb0c4e1d0f1a3e500fe949415761b8a19"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7eb6a0587eded33aeefea9f916899d42b1799b7b14b8f8ff2753c0ac1741edac"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70c8daf4faca8da5a6d655f9af86faf6ec2e1768f4b8b9d0226c02f3d6209703"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fa4c9bf273ca41f940bceb86922a7667cd5bf90e95dbb157cbb8441008482c"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:11b71d67b4725e7e2a9f6e9c0ac1239bbc0c48cce3dc59f98635efc57d6dac83"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:270755f15174fb983890c49881e93f8f1b80f0b5e3a3cc1394a255706cabd203"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c81131869240e3e568916ef4c307f8b99583efaa60a8112ef27a366eefba8ef0"}, + {file = "pydantic_core-2.20.1-cp313-none-win32.whl", hash = "sha256:b91ced227c41aa29c672814f50dbb05ec93536abf8f43cd14ec9521ea09afe4e"}, + {file = "pydantic_core-2.20.1-cp313-none-win_amd64.whl", hash = "sha256:65db0f2eefcaad1a3950f498aabb4875c8890438bc80b19362cf633b87a8ab20"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4745f4ac52cc6686390c40eaa01d48b18997cb130833154801a442323cc78f91"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8ad4c766d3f33ba8fd692f9aa297c9058970530a32c728a2c4bfd2616d3358b"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41e81317dd6a0127cabce83c0c9c3fbecceae981c8391e6f1dec88a77c8a569a"}, + {file = 
"pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04024d270cf63f586ad41fff13fde4311c4fc13ea74676962c876d9577bcc78f"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eaad4ff2de1c3823fddf82f41121bdf453d922e9a238642b1dedb33c4e4f98ad"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26ab812fa0c845df815e506be30337e2df27e88399b985d0bb4e3ecfe72df31c"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ebac750d9d5f2706654c638c041635c385596caf68f81342011ddfa1e5598"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2aafc5a503855ea5885559eae883978c9b6d8c8993d67766ee73d82e841300dd"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4868f6bd7c9d98904b748a2653031fc9c2f85b6237009d475b1008bfaeb0a5aa"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa2f457b4af386254372dfa78a2eda2563680d982422641a85f271c859df1987"}, + {file = "pydantic_core-2.20.1-cp38-none-win32.whl", hash = "sha256:225b67a1f6d602de0ce7f6c1c3ae89a4aa25d3de9be857999e9124f15dab486a"}, + {file = "pydantic_core-2.20.1-cp38-none-win_amd64.whl", hash = "sha256:6b507132dcfc0dea440cce23ee2182c0ce7aba7054576efc65634f080dbe9434"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b03f7941783b4c4a26051846dea594628b38f6940a2fdc0df00b221aed39314c"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1eedfeb6089ed3fad42e81a67755846ad4dcc14d73698c120a82e4ccf0f1f9f6"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:635fee4e041ab9c479e31edda27fcf966ea9614fff1317e280d99eb3e5ab6fe2"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:77bf3ac639c1ff567ae3b47f8d4cc3dc20f9966a2a6dd2311dcc055d3d04fb8a"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ed1b0132f24beeec5a78b67d9388656d03e6a7c837394f99257e2d55b461611"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6514f963b023aeee506678a1cf821fe31159b925c4b76fe2afa94cc70b3222b"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d4204d8ca33146e761c79f83cc861df20e7ae9f6487ca290a97702daf56006"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d036c7187b9422ae5b262badb87a20a49eb6c5238b2004e96d4da1231badef1"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9ebfef07dbe1d93efb94b4700f2d278494e9162565a54f124c404a5656d7ff09"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6b9d9bb600328a1ce523ab4f454859e9d439150abb0906c5a1983c146580ebab"}, + {file = "pydantic_core-2.20.1-cp39-none-win32.whl", hash = "sha256:784c1214cb6dd1e3b15dd8b91b9a53852aed16671cc3fbe4786f4f1db07089e2"}, + {file = "pydantic_core-2.20.1-cp39-none-win_amd64.whl", hash = "sha256:d2fe69c5434391727efa54b47a1e7986bb0186e72a41b203df8f5b0a19a4f669"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a45f84b09ac9c3d35dfcf6a27fd0634d30d183205230a0ebe8373a0e8cfa0906"}, + {file = 
"pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d02a72df14dfdbaf228424573a07af10637bd490f0901cee872c4f434a735b94"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2b27e6af28f07e2f195552b37d7d66b150adbaa39a6d327766ffd695799780f"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084659fac3c83fd674596612aeff6041a18402f1e1bc19ca39e417d554468482"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:242b8feb3c493ab78be289c034a1f659e8826e2233786e36f2893a950a719bb6"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:38cf1c40a921d05c5edc61a785c0ddb4bed67827069f535d794ce6bcded919fc"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e0bbdd76ce9aa5d4209d65f2b27fc6e5ef1312ae6c5333c26db3f5ade53a1e99"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:254ec27fdb5b1ee60684f91683be95e5133c994cc54e86a0b0963afa25c8f8a6"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:407653af5617f0757261ae249d3fba09504d7a71ab36ac057c938572d1bc9331"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c693e916709c2465b02ca0ad7b387c4f8423d1db7b4649c551f27a529181c5ad"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b5ff4911aea936a47d9376fd3ab17e970cc543d1b68921886e7f64bd28308d1"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f55a886d74f1808763976ac4efd29b7ed15c69f4d838bbd74d9d09cf6fa86"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:964faa8a861d2664f0c7ab0c181af0bea66098b1919439815ca8803ef136fc4e"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4dd484681c15e6b9a977c785a345d3e378d72678fd5f1f3c0509608da24f2ac0"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f6d6cff3538391e8486a431569b77921adfcdef14eb18fbf19b7c0a5294d4e6a"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a6d511cc297ff0883bc3708b465ff82d7560193169a8b93260f74ecb0a5e08a7"}, + {file = "pydantic_core-2.20.1.tar.gz", hash = "sha256:26ca695eeee5f9f1aeeb211ffc12f10bcb6f71e2989988fda61dabd65db878d4"}, ] [package.dependencies] @@ -2446,110 +2510,110 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rpds-py" -version = "0.18.1" +version = "0.19.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = true python-versions = ">=3.8" files = [ - {file = "rpds_py-0.18.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d31dea506d718693b6b2cffc0648a8929bdc51c70a311b2770f09611caa10d53"}, - {file = "rpds_py-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:732672fbc449bab754e0b15356c077cc31566df874964d4801ab14f71951ea80"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a98a1f0552b5f227a3d6422dbd61bc6f30db170939bd87ed14f3c339aa6c7c9"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7f1944ce16401aad1e3f7d312247b3d5de7981f634dc9dfe90da72b87d37887d"}, - {file = 
"rpds_py-0.18.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38e14fb4e370885c4ecd734f093a2225ee52dc384b86fa55fe3f74638b2cfb09"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08d74b184f9ab6289b87b19fe6a6d1a97fbfea84b8a3e745e87a5de3029bf944"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d70129cef4a8d979caa37e7fe957202e7eee8ea02c5e16455bc9808a59c6b2f0"}, - {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0bb20e3a11bd04461324a6a798af34d503f8d6f1aa3d2aa8901ceaf039176d"}, - {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81c5196a790032e0fc2464c0b4ab95f8610f96f1f2fa3d4deacce6a79852da60"}, - {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f3027be483868c99b4985fda802a57a67fdf30c5d9a50338d9db646d590198da"}, - {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d44607f98caa2961bab4fa3c4309724b185b464cdc3ba6f3d7340bac3ec97cc1"}, - {file = "rpds_py-0.18.1-cp310-none-win32.whl", hash = "sha256:c273e795e7a0f1fddd46e1e3cb8be15634c29ae8ff31c196debb620e1edb9333"}, - {file = "rpds_py-0.18.1-cp310-none-win_amd64.whl", hash = "sha256:8352f48d511de5f973e4f2f9412736d7dea76c69faa6d36bcf885b50c758ab9a"}, - {file = "rpds_py-0.18.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6b5ff7e1d63a8281654b5e2896d7f08799378e594f09cf3674e832ecaf396ce8"}, - {file = "rpds_py-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8927638a4d4137a289e41d0fd631551e89fa346d6dbcfc31ad627557d03ceb6d"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:154bf5c93d79558b44e5b50cc354aa0459e518e83677791e6adb0b039b7aa6a7"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:07f2139741e5deb2c5154a7b9629bc5aa48c766b643c1a6750d16f865a82c5fc"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c7672e9fba7425f79019db9945b16e308ed8bc89348c23d955c8c0540da0a07"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:489bdfe1abd0406eba6b3bb4fdc87c7fa40f1031de073d0cfb744634cc8fa261"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c20f05e8e3d4fc76875fc9cb8cf24b90a63f5a1b4c5b9273f0e8225e169b100"}, - {file = "rpds_py-0.18.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:967342e045564cef76dfcf1edb700b1e20838d83b1aa02ab313e6a497cf923b8"}, - {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2cc7c1a47f3a63282ab0f422d90ddac4aa3034e39fc66a559ab93041e6505da7"}, - {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f7afbfee1157e0f9376c00bb232e80a60e59ed716e3211a80cb8506550671e6e"}, - {file = "rpds_py-0.18.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9e6934d70dc50f9f8ea47081ceafdec09245fd9f6032669c3b45705dea096b88"}, - {file = "rpds_py-0.18.1-cp311-none-win32.whl", hash = "sha256:c69882964516dc143083d3795cb508e806b09fc3800fd0d4cddc1df6c36e76bb"}, - {file = "rpds_py-0.18.1-cp311-none-win_amd64.whl", hash = "sha256:70a838f7754483bcdc830444952fd89645569e7452e3226de4a613a4c1793fb2"}, - {file = "rpds_py-0.18.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3dd3cd86e1db5aadd334e011eba4e29d37a104b403e8ca24dcd6703c68ca55b3"}, - {file = 
"rpds_py-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:05f3d615099bd9b13ecf2fc9cf2d839ad3f20239c678f461c753e93755d629ee"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b2b771b13eee8729a5049c976197ff58a27a3829c018a04341bcf1ae409b2b"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee17cd26b97d537af8f33635ef38be873073d516fd425e80559f4585a7b90c43"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b646bf655b135ccf4522ed43d6902af37d3f5dbcf0da66c769a2b3938b9d8184"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19ba472b9606c36716062c023afa2484d1e4220548751bda14f725a7de17b4f6"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e30ac5e329098903262dc5bdd7e2086e0256aa762cc8b744f9e7bf2a427d3f8"}, - {file = "rpds_py-0.18.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d58ad6317d188c43750cb76e9deacf6051d0f884d87dc6518e0280438648a9ac"}, - {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e1735502458621921cee039c47318cb90b51d532c2766593be6207eec53e5c4c"}, - {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f5bab211605d91db0e2995a17b5c6ee5edec1270e46223e513eaa20da20076ac"}, - {file = "rpds_py-0.18.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2fc24a329a717f9e2448f8cd1f960f9dac4e45b6224d60734edeb67499bab03a"}, - {file = "rpds_py-0.18.1-cp312-none-win32.whl", hash = "sha256:1805d5901779662d599d0e2e4159d8a82c0b05faa86ef9222bf974572286b2b6"}, - {file = "rpds_py-0.18.1-cp312-none-win_amd64.whl", hash = "sha256:720edcb916df872d80f80a1cc5ea9058300b97721efda8651efcd938a9c70a72"}, - {file = "rpds_py-0.18.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:c827576e2fa017a081346dce87d532a5310241648eb3700af9a571a6e9fc7e74"}, - {file = "rpds_py-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:aa3679e751408d75a0b4d8d26d6647b6d9326f5e35c00a7ccd82b78ef64f65f8"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0abeee75434e2ee2d142d650d1e54ac1f8b01e6e6abdde8ffd6eeac6e9c38e20"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed402d6153c5d519a0faf1bb69898e97fb31613b49da27a84a13935ea9164dfc"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:338dee44b0cef8b70fd2ef54b4e09bb1b97fc6c3a58fea5db6cc083fd9fc2724"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7750569d9526199c5b97e5a9f8d96a13300950d910cf04a861d96f4273d5b104"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:607345bd5912aacc0c5a63d45a1f73fef29e697884f7e861094e443187c02be5"}, - {file = "rpds_py-0.18.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:207c82978115baa1fd8d706d720b4a4d2b0913df1c78c85ba73fe6c5804505f0"}, - {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6d1e42d2735d437e7e80bab4d78eb2e459af48c0a46e686ea35f690b93db792d"}, - {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5463c47c08630007dc0fe99fb480ea4f34a89712410592380425a9b4e1611d8e"}, - {file = "rpds_py-0.18.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = 
"sha256:06d218939e1bf2ca50e6b0ec700ffe755e5216a8230ab3e87c059ebb4ea06afc"}, - {file = "rpds_py-0.18.1-cp38-none-win32.whl", hash = "sha256:312fe69b4fe1ffbe76520a7676b1e5ac06ddf7826d764cc10265c3b53f96dbe9"}, - {file = "rpds_py-0.18.1-cp38-none-win_amd64.whl", hash = "sha256:9437ca26784120a279f3137ee080b0e717012c42921eb07861b412340f85bae2"}, - {file = "rpds_py-0.18.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:19e515b78c3fc1039dd7da0a33c28c3154458f947f4dc198d3c72db2b6b5dc93"}, - {file = "rpds_py-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a7b28c5b066bca9a4eb4e2f2663012debe680f097979d880657f00e1c30875a0"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:673fdbbf668dd958eff750e500495ef3f611e2ecc209464f661bc82e9838991e"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d960de62227635d2e61068f42a6cb6aae91a7fe00fca0e3aeed17667c8a34611"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352a88dc7892f1da66b6027af06a2e7e5d53fe05924cc2cfc56495b586a10b72"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4e0ee01ad8260184db21468a6e1c37afa0529acc12c3a697ee498d3c2c4dcaf3"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4c39ad2f512b4041343ea3c7894339e4ca7839ac38ca83d68a832fc8b3748ab"}, - {file = "rpds_py-0.18.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aaa71ee43a703c321906813bb252f69524f02aa05bf4eec85f0c41d5d62d0f4c"}, - {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:6cd8098517c64a85e790657e7b1e509b9fe07487fd358e19431cb120f7d96338"}, - {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4adec039b8e2928983f885c53b7cc4cda8965b62b6596501a0308d2703f8af1b"}, - {file = "rpds_py-0.18.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:32b7daaa3e9389db3695964ce8e566e3413b0c43e3394c05e4b243a4cd7bef26"}, - {file = "rpds_py-0.18.1-cp39-none-win32.whl", hash = "sha256:2625f03b105328729f9450c8badda34d5243231eef6535f80064d57035738360"}, - {file = "rpds_py-0.18.1-cp39-none-win_amd64.whl", hash = "sha256:bf18932d0003c8c4d51a39f244231986ab23ee057d235a12b2684ea26a353590"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cbfbea39ba64f5e53ae2915de36f130588bba71245b418060ec3330ebf85678e"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a3d456ff2a6a4d2adcdf3c1c960a36f4fd2fec6e3b4902a42a384d17cf4e7a65"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7700936ef9d006b7ef605dc53aa364da2de5a3aa65516a1f3ce73bf82ecfc7ae"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:51584acc5916212e1bf45edd17f3a6b05fe0cbb40482d25e619f824dccb679de"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:942695a206a58d2575033ff1e42b12b2aece98d6003c6bc739fbf33d1773b12f"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b906b5f58892813e5ba5c6056d6a5ad08f358ba49f046d910ad992196ea61397"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6f8e3fecca256fefc91bb6765a693d96692459d7d4c644660a9fff32e517843"}, - {file = 
"rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7732770412bab81c5a9f6d20aeb60ae943a9b36dcd990d876a773526468e7163"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:bd1105b50ede37461c1d51b9698c4f4be6e13e69a908ab7751e3807985fc0346"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:618916f5535784960f3ecf8111581f4ad31d347c3de66d02e728de460a46303c"}, - {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:17c6d2155e2423f7e79e3bb18151c686d40db42d8645e7977442170c360194d4"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c4c4c3f878df21faf5fac86eda32671c27889e13570645a9eea0a1abdd50922"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:fab6ce90574645a0d6c58890e9bcaac8d94dff54fb51c69e5522a7358b80ab64"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:531796fb842b53f2695e94dc338929e9f9dbf473b64710c28af5a160b2a8927d"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:740884bc62a5e2bbb31e584f5d23b32320fd75d79f916f15a788d527a5e83644"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:998125738de0158f088aef3cb264a34251908dd2e5d9966774fdab7402edfab7"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2be6e9dd4111d5b31ba3b74d17da54a8319d8168890fbaea4b9e5c3de630ae5"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0cee71bc618cd93716f3c1bf56653740d2d13ddbd47673efa8bf41435a60daa"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2c3caec4ec5cd1d18e5dd6ae5194d24ed12785212a90b37f5f7f06b8bedd7139"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:27bba383e8c5231cd559affe169ca0b96ec78d39909ffd817f28b166d7ddd4d8"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:a888e8bdb45916234b99da2d859566f1e8a1d2275a801bb8e4a9644e3c7e7909"}, - {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6031b25fb1b06327b43d841f33842b383beba399884f8228a6bb3df3088485ff"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48c2faaa8adfacefcbfdb5f2e2e7bdad081e5ace8d182e5f4ade971f128e6bb3"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:d85164315bd68c0806768dc6bb0429c6f95c354f87485ee3593c4f6b14def2bd"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6afd80f6c79893cfc0574956f78a0add8c76e3696f2d6a15bca2c66c415cf2d4"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa242ac1ff583e4ec7771141606aafc92b361cd90a05c30d93e343a0c2d82a89"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21be4770ff4e08698e1e8e0bce06edb6ea0626e7c8f560bc08222880aca6a6f"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c45a639e93a0c5d4b788b2613bd637468edd62f8f95ebc6fcc303d58ab3f0a8"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:910e71711d1055b2768181efa0a17537b2622afeb0424116619817007f8a2b10"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b9bb1f182a97880f6078283b3505a707057c42bf55d8fca604f70dedfdc0772a"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d54f74f40b1f7aaa595a02ff42ef38ca654b1469bef7d52867da474243cc633"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:8d2e182c9ee01135e11e9676e9a62dfad791a7a467738f06726872374a83db49"}, - {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:636a15acc588f70fda1661234761f9ed9ad79ebed3f2125d44be0862708b666e"}, - {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, + {file = "rpds_py-0.19.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:fb37bd599f031f1a6fb9e58ec62864ccf3ad549cf14bac527dbfa97123edcca4"}, + {file = "rpds_py-0.19.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3384d278df99ec2c6acf701d067147320b864ef6727405d6470838476e44d9e8"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e54548e0be3ac117595408fd4ca0ac9278fde89829b0b518be92863b17ff67a2"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8eb488ef928cdbc05a27245e52de73c0d7c72a34240ef4d9893fdf65a8c1a955"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5da93debdfe27b2bfc69eefb592e1831d957b9535e0943a0ee8b97996de21b5"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79e205c70afddd41f6ee79a8656aec738492a550247a7af697d5bd1aee14f766"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:959179efb3e4a27610e8d54d667c02a9feaa86bbabaf63efa7faa4dfa780d4f1"}, + {file = "rpds_py-0.19.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a6e605bb9edcf010f54f8b6a590dd23a4b40a8cb141255eec2a03db249bc915b"}, + {file = "rpds_py-0.19.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9133d75dc119a61d1a0ded38fb9ba40a00ef41697cc07adb6ae098c875195a3f"}, + {file = "rpds_py-0.19.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dd36b712d35e757e28bf2f40a71e8f8a2d43c8b026d881aa0c617b450d6865c9"}, + {file = "rpds_py-0.19.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:354f3a91718489912f2e0fc331c24eaaf6a4565c080e00fbedb6015857c00582"}, + {file = "rpds_py-0.19.0-cp310-none-win32.whl", hash = "sha256:ebcbf356bf5c51afc3290e491d3722b26aaf5b6af3c1c7f6a1b757828a46e336"}, + {file = "rpds_py-0.19.0-cp310-none-win_amd64.whl", hash = "sha256:75a6076289b2df6c8ecb9d13ff79ae0cad1d5fb40af377a5021016d58cd691ec"}, + {file = "rpds_py-0.19.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6d45080095e585f8c5097897313def60caa2046da202cdb17a01f147fb263b81"}, + {file = "rpds_py-0.19.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c5c9581019c96f865483d031691a5ff1cc455feb4d84fc6920a5ffc48a794d8a"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1540d807364c84516417115c38f0119dfec5ea5c0dd9a25332dea60b1d26fc4d"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9e65489222b410f79711dc3d2d5003d2757e30874096b2008d50329ea4d0f88c"}, + {file = 
"rpds_py-0.19.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9da6f400eeb8c36f72ef6646ea530d6d175a4f77ff2ed8dfd6352842274c1d8b"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37f46bb11858717e0efa7893c0f7055c43b44c103e40e69442db5061cb26ed34"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:071d4adc734de562bd11d43bd134330fb6249769b2f66b9310dab7460f4bf714"}, + {file = "rpds_py-0.19.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9625367c8955e4319049113ea4f8fee0c6c1145192d57946c6ffcd8fe8bf48dd"}, + {file = "rpds_py-0.19.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e19509145275d46bc4d1e16af0b57a12d227c8253655a46bbd5ec317e941279d"}, + {file = "rpds_py-0.19.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d438e4c020d8c39961deaf58f6913b1bf8832d9b6f62ec35bd93e97807e9cbc"}, + {file = "rpds_py-0.19.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:90bf55d9d139e5d127193170f38c584ed3c79e16638890d2e36f23aa1630b952"}, + {file = "rpds_py-0.19.0-cp311-none-win32.whl", hash = "sha256:8d6ad132b1bc13d05ffe5b85e7a01a3998bf3a6302ba594b28d61b8c2cf13aaf"}, + {file = "rpds_py-0.19.0-cp311-none-win_amd64.whl", hash = "sha256:7ec72df7354e6b7f6eb2a17fa6901350018c3a9ad78e48d7b2b54d0412539a67"}, + {file = "rpds_py-0.19.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:5095a7c838a8647c32aa37c3a460d2c48debff7fc26e1136aee60100a8cd8f68"}, + {file = "rpds_py-0.19.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f2f78ef14077e08856e788fa482107aa602636c16c25bdf59c22ea525a785e9"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7cc6cb44f8636fbf4a934ca72f3e786ba3c9f9ba4f4d74611e7da80684e48d2"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf902878b4af334a09de7a45badbff0389e7cf8dc2e4dcf5f07125d0b7c2656d"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:688aa6b8aa724db1596514751ffb767766e02e5c4a87486ab36b8e1ebc1aedac"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57dbc9167d48e355e2569346b5aa4077f29bf86389c924df25c0a8b9124461fb"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b4cf5a9497874822341c2ebe0d5850fed392034caadc0bad134ab6822c0925b"}, + {file = "rpds_py-0.19.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8a790d235b9d39c70a466200d506bb33a98e2ee374a9b4eec7a8ac64c2c261fa"}, + {file = "rpds_py-0.19.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1d16089dfa58719c98a1c06f2daceba6d8e3fb9b5d7931af4a990a3c486241cb"}, + {file = "rpds_py-0.19.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bc9128e74fe94650367fe23f37074f121b9f796cabbd2f928f13e9661837296d"}, + {file = "rpds_py-0.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c8f77e661ffd96ff104bebf7d0f3255b02aa5d5b28326f5408d6284c4a8b3248"}, + {file = "rpds_py-0.19.0-cp312-none-win32.whl", hash = "sha256:5f83689a38e76969327e9b682be5521d87a0c9e5a2e187d2bc6be4765f0d4600"}, + {file = "rpds_py-0.19.0-cp312-none-win_amd64.whl", hash = "sha256:06925c50f86da0596b9c3c64c3837b2481337b83ef3519e5db2701df695453a4"}, + {file = "rpds_py-0.19.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:52e466bea6f8f3a44b1234570244b1cff45150f59a4acae3fcc5fd700c2993ca"}, + {file = 
"rpds_py-0.19.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e21cc693045fda7f745c790cb687958161ce172ffe3c5719ca1764e752237d16"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b31f059878eb1f5da8b2fd82480cc18bed8dcd7fb8fe68370e2e6285fa86da6"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dd46f309e953927dd018567d6a9e2fb84783963650171f6c5fe7e5c41fd5666"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:34a01a4490e170376cd79258b7f755fa13b1a6c3667e872c8e35051ae857a92b"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bcf426a8c38eb57f7bf28932e68425ba86def6e756a5b8cb4731d8e62e4e0223"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68eea5df6347d3f1378ce992d86b2af16ad7ff4dcb4a19ccdc23dea901b87fb"}, + {file = "rpds_py-0.19.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dab8d921b55a28287733263c0e4c7db11b3ee22aee158a4de09f13c93283c62d"}, + {file = "rpds_py-0.19.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6fe87efd7f47266dfc42fe76dae89060038f1d9cb911f89ae7e5084148d1cc08"}, + {file = "rpds_py-0.19.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:535d4b52524a961d220875688159277f0e9eeeda0ac45e766092bfb54437543f"}, + {file = "rpds_py-0.19.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:8b1a94b8afc154fbe36978a511a1f155f9bd97664e4f1f7a374d72e180ceb0ae"}, + {file = "rpds_py-0.19.0-cp38-none-win32.whl", hash = "sha256:7c98298a15d6b90c8f6e3caa6457f4f022423caa5fa1a1ca7a5e9e512bdb77a4"}, + {file = "rpds_py-0.19.0-cp38-none-win_amd64.whl", hash = "sha256:b0da31853ab6e58a11db3205729133ce0df26e6804e93079dee095be3d681dc1"}, + {file = "rpds_py-0.19.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5039e3cef7b3e7a060de468a4a60a60a1f31786da94c6cb054e7a3c75906111c"}, + {file = "rpds_py-0.19.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab1932ca6cb8c7499a4d87cb21ccc0d3326f172cfb6a64021a889b591bb3045c"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2afd2164a1e85226fcb6a1da77a5c8896c18bfe08e82e8ceced5181c42d2179"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b1c30841f5040de47a0046c243fc1b44ddc87d1b12435a43b8edff7e7cb1e0d0"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f757f359f30ec7dcebca662a6bd46d1098f8b9fb1fcd661a9e13f2e8ce343ba1"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15e65395a59d2e0e96caf8ee5389ffb4604e980479c32742936ddd7ade914b22"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb0f6eb3a320f24b94d177e62f4074ff438f2ad9d27e75a46221904ef21a7b05"}, + {file = "rpds_py-0.19.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b228e693a2559888790936e20f5f88b6e9f8162c681830eda303bad7517b4d5a"}, + {file = "rpds_py-0.19.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2575efaa5d949c9f4e2cdbe7d805d02122c16065bfb8d95c129372d65a291a0b"}, + {file = "rpds_py-0.19.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:5c872814b77a4e84afa293a1bee08c14daed1068b2bb1cc312edbf020bbbca2b"}, + {file = "rpds_py-0.19.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = 
"sha256:850720e1b383df199b8433a20e02b25b72f0fded28bc03c5bd79e2ce7ef050be"}, + {file = "rpds_py-0.19.0-cp39-none-win32.whl", hash = "sha256:ce84a7efa5af9f54c0aa7692c45861c1667080814286cacb9958c07fc50294fb"}, + {file = "rpds_py-0.19.0-cp39-none-win_amd64.whl", hash = "sha256:1c26da90b8d06227d7769f34915913911222d24ce08c0ab2d60b354e2d9c7aff"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:75969cf900d7be665ccb1622a9aba225cf386bbc9c3bcfeeab9f62b5048f4a07"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8445f23f13339da640d1be8e44e5baf4af97e396882ebbf1692aecd67f67c479"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5a7c1062ef8aea3eda149f08120f10795835fc1c8bc6ad948fb9652a113ca55"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:462b0c18fbb48fdbf980914a02ee38c423a25fcc4cf40f66bacc95a2d2d73bc8"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3208f9aea18991ac7f2b39721e947bbd752a1abbe79ad90d9b6a84a74d44409b"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3444fe52b82f122d8a99bf66777aed6b858d392b12f4c317da19f8234db4533"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88cb4bac7185a9f0168d38c01d7a00addece9822a52870eee26b8d5b61409213"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6b130bd4163c93798a6b9bb96be64a7c43e1cec81126ffa7ffaa106e1fc5cef5"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:a707b158b4410aefb6b054715545bbb21aaa5d5d0080217290131c49c2124a6e"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dc9ac4659456bde7c567107556ab065801622396b435a3ff213daef27b495388"}, + {file = "rpds_py-0.19.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:81ea573aa46d3b6b3d890cd3c0ad82105985e6058a4baed03cf92518081eec8c"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3f148c3f47f7f29a79c38cc5d020edcb5ca780020fab94dbc21f9af95c463581"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0906357f90784a66e89ae3eadc2654f36c580a7d65cf63e6a616e4aec3a81be"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f629ecc2db6a4736b5ba95a8347b0089240d69ad14ac364f557d52ad68cf94b0"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c6feacd1d178c30e5bc37184526e56740342fd2aa6371a28367bad7908d454fc"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae8b6068ee374fdfab63689be0963333aa83b0815ead5d8648389a8ded593378"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:78d57546bad81e0da13263e4c9ce30e96dcbe720dbff5ada08d2600a3502e526"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8b6683a37338818646af718c9ca2a07f89787551057fae57c4ec0446dc6224b"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e8481b946792415adc07410420d6fc65a352b45d347b78fec45d8f8f0d7496f0"}, + {file = 
"rpds_py-0.19.0-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:bec35eb20792ea64c3c57891bc3ca0bedb2884fbac2c8249d9b731447ecde4fa"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:aa5476c3e3a402c37779e95f7b4048db2cb5b0ed0b9d006983965e93f40fe05a"}, + {file = "rpds_py-0.19.0-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:19d02c45f2507b489fd4df7b827940f1420480b3e2e471e952af4d44a1ea8e34"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a3e2fd14c5d49ee1da322672375963f19f32b3d5953f0615b175ff7b9d38daed"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:93a91c2640645303e874eada51f4f33351b84b351a689d470f8108d0e0694210"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5b9fc03bf76a94065299d4a2ecd8dfbae4ae8e2e8098bbfa6ab6413ca267709"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a4b07cdf3f84310c08c1de2c12ddadbb7a77568bcb16e95489f9c81074322ed"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba0ed0dc6763d8bd6e5de5cf0d746d28e706a10b615ea382ac0ab17bb7388633"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:474bc83233abdcf2124ed3f66230a1c8435896046caa4b0b5ab6013c640803cc"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:329c719d31362355a96b435f4653e3b4b061fcc9eba9f91dd40804ca637d914e"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef9101f3f7b59043a34f1dccbb385ca760467590951952d6701df0da9893ca0c"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:0121803b0f424ee2109d6e1f27db45b166ebaa4b32ff47d6aa225642636cd834"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:8344127403dea42f5970adccf6c5957a71a47f522171fafaf4c6ddb41b61703a"}, + {file = "rpds_py-0.19.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:443cec402ddd650bb2b885113e1dcedb22b1175c6be223b14246a714b61cd521"}, + {file = "rpds_py-0.19.0.tar.gz", hash = "sha256:4fdc9afadbeb393b4bbbad75481e0ea78e4469f2e1d713a90811700830b553a9"}, ] [[package]] @@ -2772,18 +2836,19 @@ files = [ [[package]] name = "setuptools" -version = "70.0.0" +version = "71.1.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, - {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, + {file = "setuptools-71.1.0-py3-none-any.whl", hash = "sha256:33874fdc59b3188304b2e7c80d9029097ea31627180896fb549c578ceb8a0855"}, + {file = "setuptools-71.1.0.tar.gz", hash = "sha256:032d42ee9fb536e33087fb66cac5f840eb9391ed05637b3f2a76a7c8fb477936"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", 
"ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -2798,30 +2863,20 @@ files = [ [[package]] name = "sympy" -version = "1.12.1" +version = "1.13.1" description = "Computer algebra system (CAS) in Python" optional = true python-versions = ">=3.8" files = [ - {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, - {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, + {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, + {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, ] [package.dependencies] -mpmath = ">=1.1.0,<1.4.0" +mpmath = ">=1.1.0,<1.4" -[[package]] -name = "tbb" -version = "2021.12.0" -description = "Intel® oneAPI Threading Building Blocks (oneTBB)" -optional = true -python-versions = "*" -files = [ - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, - {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, - {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, -] +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] [[package]] name = "texttable" @@ -2964,44 +3019,43 @@ files = [ [[package]] name = "torch" -version = "2.3.0" +version = "2.4.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = true python-versions = ">=3.8.0" files = [ - {file = 
"torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, - {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, - {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, - {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, - {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, - {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, - {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, - {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, - {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"}, - {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"}, - {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"}, - {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"}, - {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"}, - {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"}, - {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"}, - {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"}, - {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"}, - {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"}, - {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"}, - {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"}, + {file = "torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:4ed94583e244af51d6a8d28701ca5a9e02d1219e782f5a01dd401f90af17d8ac"}, + {file = "torch-2.4.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c4ca297b7bd58b506bfd6e78ffd14eb97c0e7797dcd7965df62f50bb575d8954"}, + {file = "torch-2.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:2497cbc7b3c951d69b276ca51fe01c2865db67040ac67f5fc20b03e41d16ea4a"}, + {file = "torch-2.4.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:685418ab93730efbee71528821ff54005596970dd497bf03c89204fb7e3f71de"}, + {file = "torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e743adadd8c8152bb8373543964551a7cb7cc20ba898dc8f9c0cdbe47c283de0"}, + {file = 
"torch-2.4.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:7334325c0292cbd5c2eac085f449bf57d3690932eac37027e193ba775703c9e6"}, + {file = "torch-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:97730014da4c57ffacb3c09298c6ce05400606e890bd7a05008d13dd086e46b1"}, + {file = "torch-2.4.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f169b4ea6dc93b3a33319611fcc47dc1406e4dd539844dcbd2dec4c1b96e166d"}, + {file = "torch-2.4.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:997084a0f9784d2a89095a6dc67c7925e21bf25dea0b3d069b41195016ccfcbb"}, + {file = "torch-2.4.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc3988e8b36d1e8b998d143255d9408d8c75da4ab6dd0dcfd23b623dfb0f0f57"}, + {file = "torch-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3374128bbf7e62cdaed6c237bfd39809fbcfaa576bee91e904706840c3f2195c"}, + {file = "torch-2.4.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:91aaf00bfe1ffa44dc5b52809d9a95129fca10212eca3ac26420eb11727c6288"}, + {file = "torch-2.4.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cc30457ea5489c62747d3306438af00c606b509d78822a88f804202ba63111ed"}, + {file = "torch-2.4.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a046491aaf96d1215e65e1fa85911ef2ded6d49ea34c8df4d0638879f2402eef"}, + {file = "torch-2.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:688eec9240f3ce775f22e1e1a5ab9894f3d5fe60f3f586deb7dbd23a46a83916"}, + {file = "torch-2.4.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:3af4de2a618fb065e78404c4ba27a818a7b7957eaeff28c6c66ce7fb504b68b8"}, + {file = "torch-2.4.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:618808d3f610d5f180e47a697d4ec90b810953bb1e020f424b2ac7fb0884b545"}, + {file = "torch-2.4.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ed765d232d23566052ba83632ec73a4fccde00b4c94ad45d63b471b09d63b7a7"}, + {file = "torch-2.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2feb98ac470109472fb10dfef38622a7ee08482a16c357863ebc7bc7db7c8f7"}, + {file = "torch-2.4.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:8940fc8b97a4c61fdb5d46a368f21f4a3a562a17879e932eb51a5ec62310cb31"}, ] [package.dependencies] filelock = "*" fsspec = "*" jinja2 = "*" -mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} @@ -3009,12 +3063,12 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \" nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and 
platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +triton = {version = "3.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.9.1)"] +optree = ["optree (>=0.11.0)"] [[package]] name = "tqdm" @@ -3038,18 +3092,18 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.41.2" +version = "4.43.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.41.2-py3-none-any.whl", hash = "sha256:05555d20e43f808de1ef211ab64803cdb513170cef70d29a888b589caebefc67"}, - {file = "transformers-4.41.2.tar.gz", hash = "sha256:80a4db216533d573e9cc7388646c31ed9480918feb7c55eb211249cb23567f87"}, + {file = "transformers-4.43.1-py3-none-any.whl", hash = "sha256:eb44b731902e062acbaff196ae4896d7cb3494ddf38275aa00a5fcfb5b34f17d"}, + {file = "transformers-4.43.1.tar.gz", hash = "sha256:662252c4d0e31b6684f68f68d5cc8206dd7f83da80eb3235be3dc5b3c9fdbdbd"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.23.0,<1.0" +huggingface-hub = ">=0.23.2,<1.0" numpy = ">=1.17" packaging = ">=20.0" pyyaml = ">=5.1" @@ -3062,14 +3116,15 @@ tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.21.0)"] agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +benchmark = ["optimum-benchmark (>=0.2.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", 
"pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", 
"psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] 
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] @@ -3080,41 +3135,46 @@ natten = ["natten (>=0.14.6,<0.15.0)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.4.4)", "urllib3 (<2.0.0)"] ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +ruff = ["ruff (==0.4.4)"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] serving = ["fastapi", "pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.4.4)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -timm = ["timm"] +timm = ["timm (<=0.9.16)"] tokenizers = ["tokenizers (>=0.19,<0.20)"] torch = ["accelerate (>=0.21.0)", "torch"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.23.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.3.0" +version = "3.0.0" 
description = "A language and compiler for custom Deep Learning operations" optional = true python-versions = "*" files = [ - {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, - {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, - {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, - {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"}, - {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, - {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, + {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, + {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, + {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, + {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, + {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -3122,8 +3182,8 @@ filelock = "*" [package.extras] build = ["cmake (>=3.20)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] -tutorials = ["matplotlib", "pandas", "tabulate", "torch"] +tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "typer" @@ -3147,13 +3207,13 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6. 
[[package]] name = "typing-extensions" -version = "4.12.1" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, - {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -3169,13 +3229,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] @@ -3499,22 +3559,23 @@ multidict = ">=4.0" [[package]] name = "zipp" -version = "3.19.1" +version = "3.19.2" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, - {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, + {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, + {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, ] [package.extras] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] accelerate = ["accelerate"] bnb = ["bitsandbytes"] +marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"] outlines = ["outlines"] peft = ["peft"] quantize = ["accelerate", "datasets", "texttable"] @@ -3523,4 +3584,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "f62a7a74e1e1bcb3b7cb4f7da2b538065830748062a2b57fdbb4c76eae5abddc" +content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1" diff --git a/server/pyproject.toml b/server/pyproject.toml index 7b5e83fb7..15da4a8fb 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,22 +26,32 @@ 
hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" tokenizers = "^0.19.1" huggingface-hub = "^0.23" -transformers = "^4.41" +transformers = "^4.43" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } peft = { version = "^0.10", optional = true } -torch = { version = "^2.3.0", optional = true } +torch = { version = "^2.4.0", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" outlines= { version = "^0.0.34", optional = true } prometheus-client = "^0.20.0" py-cpuinfo = "^9.0.0" +# Remove later, temporary workaround for outlines. +numpy = "^1.26" + +marlin-kernels = [ + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, +] [tool.poetry.extras] torch = ["torch"] accelerate = ["accelerate"] bnb = ["bitsandbytes"] +marlin = ["marlin-kernels"] peft = ["peft"] quantize = ["texttable", "datasets", "accelerate"] outlines = ["outlines"] diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 88fcc4f36..828b6fca9 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -1,48 +1,50 @@ -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" 
+importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.0 ; python_version >= "3.9" and 
python_version < "3.13" -urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index 5751bf816..828b6fca9 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -1,48 +1,50 @@ -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.0 ; python_version >= "3.9" and 
python_version < "3.13" -pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" +setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 88fcc4f36..828b6fca9 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -1,48 +1,50 @@ -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and 
python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" 
+pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" -requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py new file mode 100644 index 000000000..cc1b076da --- /dev/null +++ b/server/tests/utils/test_adapter.py @@ -0,0 +1,187 @@ +import pytest +from unittest.mock import Mock +from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights + + +def test_get_attn_weights(): + # create a mock layer + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + # call the function + result = get_attn_weights(2, mock_layer) + + # assert the result + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_with_gate_up_proj(): + # create a mock layer with gate_up_proj + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # call the function + result = get_mlp_weights(3, mock_layer) + + # assert the result + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_mlp_weights_without_gate_up_proj(): + # create a mock layer without gate_up_proj + mock_layer = 
Mock() + mock_layer.mlp = Mock(spec=[]) + + # call the function + result = get_mlp_weights(1, mock_layer) + + # assert the result + assert result == {} + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_attn_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(layer_index, mock_layer) + + for k in ["q", "k", "v"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.self_attn.{k}_proj" + ) + + assert (layer_index, "o_proj") in result + assert ( + result[(layer_index, "o_proj")][0] + == f"model.layers.{layer_index}.self_attn.o_proj" + ) + + +@pytest.mark.parametrize("layer_index", [0, 1, 5]) +def test_get_mlp_weights_different_layers(layer_index): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(layer_index, mock_layer) + + for k in ["gate", "up", "down"]: + assert (layer_index, f"{k}_proj") in result + assert ( + result[(layer_index, f"{k}_proj")][0] + == f"model.layers.{layer_index}.mlp.{k}_proj" + ) + + +def test_get_attn_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_llama_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected + + +def test_get_attn_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.self_attn.query_key_value = Mock() + mock_layer.self_attn.o_proj = Mock() + + result = get_attn_weights(2, mock_layer) + + expected = { + (2, "q_proj"): ( + "model.layers.2.self_attn.q_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "k_proj"): ( + "model.layers.2.self_attn.k_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "v_proj"): ( + "model.layers.2.self_attn.v_proj", + mock_layer.self_attn.query_key_value, + ), + (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj), + } + assert result == expected + + +def test_get_mlp_weights_gemma_compatibility(): + mock_layer = Mock() + mock_layer.mlp.gate_proj = Mock() + mock_layer.mlp.up_proj = Mock() + mock_layer.mlp.down_proj = Mock() + + # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist. + # This is necessary because the use of `Mock` automatically creates any + # attributes that are accessed, even if they don't exist in the actual + # implementation. If `gate_up_proj` were created, `get_mlp_weights` might + # follow the wrong execution path and return an incorrect result. 
+ del mock_layer.mlp.gate_up_proj + + result = get_mlp_weights(3, mock_layer) + + expected = { + (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), + (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), + (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), + } + assert result == expected diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index 721820f51..291a41b05 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -1,11 +1,9 @@ import os -import requests import tempfile import pytest import huggingface_hub.constants -from huggingface_hub import hf_api import text_generation_server.utils.hub from text_generation_server.utils.hub import ( diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 1e3aaf6b6..118540eed 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -2,7 +2,6 @@ import torch from text_generation_server.layers import ( TensorParallelEmbedding, ) -from text_generation_server.utils.weights import DefaultWeightsLoader class ProcessGroup: diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 36b27be8c..556fcea1e 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -7,7 +7,10 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader -from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader +from text_generation_server.layers.marlin.marlin import ( + MarlinWeight, + MarlinWeightsLoader, +) from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path @@ -82,15 +85,6 @@ dummy_file_system = { ], dtype=torch.float32, ), - "weight.weight": torch.tensor( - [ - [1, 2], - [3, 4], - [5, 6], - [7, 8], - ], - dtype=torch.float32, - ), }, "test_get_weights_row": { "weight.weight": torch.tensor( @@ -363,7 +357,10 @@ class MockWeights(Weights): self.process_group = process_group self.prefix = prefix self.weights_loader = ( - DefaultWeightsLoader() if weights_loader is None else weights_loader + # We don't need to get linear layers, so just wrap raw tensors. 
+ DefaultWeightsLoader(lambda x: x) + if weights_loader is None + else weights_loader ) self._handles = {} @@ -632,6 +629,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -641,6 +639,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -669,6 +668,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -678,6 +678,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -774,6 +775,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -783,6 +785,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -851,6 +854,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -860,6 +864,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -922,6 +927,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -931,6 +937,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -949,7 +956,7 @@ def test_get_multi_weights_col_exl2(): prefix = "weight" try: - w = weights.get_multi_weights_col( + weights.get_multi_weights_col( prefixes=[prefix], dim=0, ) @@ -983,6 +990,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -992,6 +1000,7 @@ def 
test_get_multi_weights_col_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -1051,6 +1060,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -1060,6 +1070,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -1125,6 +1136,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -1134,6 +1146,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py index 5261d4b50..b7e270900 100644 --- a/server/text_generation_server/adapters/config.py +++ b/server/text_generation_server/adapters/config.py @@ -4,15 +4,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import Dict, Set, Tuple import torch from text_generation_server.adapters.weights import AdapterWeights -if TYPE_CHECKING: - from text_generation_server.models.model import Model - @dataclass class ModuleMap: @@ -31,14 +28,3 @@ class AdapterConfig(ABC): weight_names: Tuple[str], ) -> Tuple[ModuleMap, Set[str]]: pass - - @abstractmethod - def load_batched_adapter_weights( - self, - model: "Model", - module_map: ModuleMap, - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - pass diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 87543be2b..a00338e7c 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch from peft import LoraConfig as _LoraConfig @@ -26,9 +26,6 @@ from text_generation_server.utils.sgmv import ( use_cutlass_shrink, ) -if TYPE_CHECKING: - from text_generation_server.models.model import Model - def get_start_stop_idxs_for_rank(offset, size, rank, world_size): block_size = size // world_size @@ -102,22 +99,6 @@ class LoraConfig(AdapterConfig): adapter_weight_names.add(lora_b_name) return module_map, adapter_weight_names - def load_batched_adapter_weights( - 
self, - model: "Model", - module_map: Dict[str, Dict], - layer_type: str, - unused_weight_names: Set[str], - dynamic: bool, - ) -> Optional[AdapterWeights]: - return LoraWeights.load( - self, - model, - module_map, - layer_type, - unused_weight_names, - ) - @classmethod def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) @@ -192,22 +173,38 @@ class LoraWeights(AdapterWeights): def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: return [BatchLoraWeights] + # prepare pre-loaded lora weights for use in the model. + # + # this method processes and organizes lora weights for a specific layer type across all layers: + # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor. + # - retrieves weights from `module_map` based on the `layer_type`. + # - processes `nlayers` number of layers. + # - converts weights to the specified `dtype`. + # - shards weights across `world_size` number of processes using the `process_group`. + # - maps weights to specific layers using `target_to_layer`. + # - tracks `unused_weight_names` to identify any unused weights. + # + # the method handles weight transposition, scaling, and padding to ensure compatibility + # with SGMV or BGMV operations. @classmethod - def load( + def prepare_weights( cls, config: LoraConfig, - model: "Model", module_map: Dict[str, Dict], layer_type: str, unused_weight_names: Set[str], + nlayers: int, + dtype: torch.dtype, + world_size: int, + process_group: ProcessGroup, + target_to_layer: Dict[str, Tuple[str, torch.Tensor]], ) -> Optional[AdapterWeights]: - nlayers = model.get_num_layers_for_type(layer_type) lora_a_list = [None] * nlayers lora_b_list = [None] * nlayers for layer_id in range(nlayers): key = (layer_id, layer_type) - weight_name, layer = model.target_to_layer[key] + weight_name, layer = target_to_layer[key] base_weight = layer.base_layer.linear.weight base_device = base_weight.device @@ -216,10 +213,10 @@ class LoraWeights(AdapterWeights): return None lora_a, lora_a_name = module_map[weight_name]["lora_A"] - lora_a = lora_a.to(base_device, model.dtype) + lora_a = lora_a.to(base_device, dtype) lora_b, lora_b_name = module_map[weight_name]["lora_B"] - lora_b = lora_b.to(base_device, model.dtype) + lora_b = lora_b.to(base_device, dtype) scale = get_scaling_factor( config.lora_alpha, @@ -236,12 +233,8 @@ class LoraWeights(AdapterWeights): lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale # pad lora ranks to be compatible with sgmv - lora_a_list = [ - pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list - ] - lora_b_list = [ - pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list - ] + lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] + lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] if lora_a_list: # update rank if it was padded @@ -252,8 +245,8 @@ class LoraWeights(AdapterWeights): *shard_lora_weights( weights_a=lora_a_list, weights_b=lora_b_list, - split_dim=0 if model.is_row_parallel(layer_type) else 1, - process_group=model.process_group, + split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1, + process_group=process_group, ), config, ) @@ -293,10 +286,6 @@ class BatchLoraWeights(BatchAdapterWeights): for rank_data in self.rank_data.values() ) - @classmethod - def key(cls) -> str: - return "lora" - @classmethod def load( self, diff --git a/server/text_generation_server/adapters/weights.py 
b/server/text_generation_server/adapters/weights.py index 8f6587567..da75dbcdf 100644 --- a/server/text_generation_server/adapters/weights.py +++ b/server/text_generation_server/adapters/weights.py @@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC): def has_adapter(self, adapter_index: int) -> bool: pass - @abstractclassmethod - def key(cls) -> str: - pass - @abstractclassmethod def load( cls, @@ -71,13 +67,6 @@ class LayerAdapterWeights: return del self.adapter_weights[adapter_idx] - @property - def max_speculative_tokens(self) -> int: - return max( - adapter_weights.speculative_tokens - for adapter_weights in self.adapter_weights.values() - ) - def is_empty(self) -> bool: return len(self.adapter_weights) == 0 @@ -101,7 +90,7 @@ class LayerAdapterWeights: adapter_weights, meta, prefill, prefill_head_indices ) if batched_weights is not None: - batch_data[batch_type.key()] = batched_weights + batch_data = batched_weights return batch_data @@ -133,8 +122,7 @@ class AdapterBatchData: def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation ranks = set() - for layer_data in self.data.values(): - lora_data = layer_data.get("lora") + for lora_data in self.data.values(): if lora_data is None: continue diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68ae95dd7..10aa3a3b2 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -7,6 +7,7 @@ from loguru import logger from typing import Optional from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.adapter import parse_lora_adapters app = typer.Typer() @@ -79,17 +80,19 @@ def serve( if otlp_endpoint is not None: setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) - lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) + lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) - # split on comma and strip whitespace - lora_adapter_ids = ( - [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else [] - ) + # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled + # and warn the user + if lora_adapters: + logger.warning("LoRA adapters enabled (experimental feature).") - if len(lora_adapter_ids) > 0: - logger.warning( - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." - ) + if "CUDA_GRAPHS" in os.environ: + logger.warning( + "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs." + ) + global CUDA_GRAPHS + CUDA_GRAPHS = None # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value @@ -105,7 +108,7 @@ def serve( ) server.serve( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, @@ -161,7 +164,7 @@ def download_weights( # currently by default we don't merge the weights with the base model if merge_lora: try: - adapter_config_filename = hf_hub_download( + hf_hub_download( model_id, revision=revision, filename="adapter_config.json" ) utils.download_and_unload_peft( @@ -281,9 +284,9 @@ def download_weights( if auto_convert: if not trust_remote_code: logger.warning( - f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because " - f"Pickle files are unsafe and can essentially contain remote code execution!" 
- f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety", + "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because " + "Pickle files are unsafe and can essentially contain remote code execution!" + "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety", ) logger.warning( @@ -315,7 +318,7 @@ def download_weights( # Name for this varible depends on transformers version. discard_names = getattr(class_, "_tied_weights_keys", []) - except Exception as e: + except Exception: discard_names = [] # Convert pytorch weights to safetensors utils.convert_files(local_pt_files, local_st_files, discard_names) @@ -332,6 +335,7 @@ def quantize( upload_to_model_id: Optional[str] = None, percdamp: float = 0.01, act_order: bool = False, + groupsize: int = 128, ): if revision is None: revision = "main" @@ -346,13 +350,14 @@ def quantize( quantize( model_id=model_id, bits=4, - groupsize=128, + groupsize=groupsize, output_dir=output_dir, revision=revision, trust_remote_code=trust_remote_code, upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, + sym=True, ) diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index 32c8d121b..0000ca915 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -18,3 +18,17 @@ from text_generation_server.layers.lora import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) + +__all__ = [ + "get_linear", + "FastLinear", + "TensorParallelColumnLinear", + "TensorParallelRowLinear", + "TensorParallelEmbedding", + "SpeculativeHead", + "LoraLinear", + "TensorParallelMultiAdapterLinear", + "TensorParallelAdapterRowLinear", + "load_layer_norm", + "load_conv2d", +] diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index c8bccefec..f9b1715ef 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -13,3 +13,12 @@ elif SYSTEM == "ipex": from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") + + +__all__ = [ + "attention", + "paged_attention", + "reshape_and_cache", + "SUPPORTS_WINDOWING", + "Seqlen", +] diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 94b69899e..dff742dc1 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -2,6 +2,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE from text_generation_server.layers.attention import Seqlen +from typing import Optional major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 @@ -9,7 +10,6 @@ _PARTITION_SIZE = 512 try: from vllm._C import cache_ops - from vllm._C import ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. 
Complete error: {e}" @@ -34,7 +34,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -43,6 +42,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Copyright 2023 The vLLM team. All rights @@ -82,13 +82,16 @@ def paged_attention( # by the current path # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. - out2 = flash_attn_2_cuda.varlen_fwd( + if softcap is None: + softcap = 0.0 + out = flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, None, seqlen.cu_seqlen_q, seqlen.cu_seqlen_k, + None, # pad_k None, block_tables, None, @@ -100,14 +103,19 @@ def paged_attention( True, # causal -1, # Window_left -1, # Window right + softcap, False, # return softmax None, # generator ) - return out2[0] + return out[0] else: + if softcap is not None: + raise RuntimeError("Paged attention doesn't support softcapping") input_lengths = seqlen.input_lengths from vllm._C import ops + out = torch.empty_like(query) + use_v1 = max_s <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512 ) @@ -193,19 +201,21 @@ except ImportError: SUPPORTS_WINDOWING = V2 + if V2: def attention( q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, causal=True, + softcap=0.0, ): + out = torch.empty_like(q) if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( @@ -218,6 +228,7 @@ if V2: None, None, None, + None, max_s, max_s, 0.0, @@ -226,9 +237,10 @@ if V2: causal, window_size_left, 0, + softcap, False, None, - ) + )[0] else: @@ -236,16 +248,18 @@ else: q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, + softcap=None, ): if window_size_left != -1: raise NotImplementedError( "window_size_left is only available with flash attn v2" ) + if softcap is not None: + raise NotImplementedError("softcap is only available with flash attn v2") # Flash attention v1 requires q, k and v to have the same number of heads if k.shape[1] != q.shape[1]: @@ -273,6 +287,8 @@ else: .reshape(original_shape[0], -1, original_shape[2]) ) + out = torch.empty_like(q) + return flash_attn_cuda.fwd( q, k, @@ -289,4 +305,4 @@ else: False, 0, None, - ) + )[0] diff --git a/server/text_generation_server/layers/attention/flash_attn_triton.py b/server/text_generation_server/layers/attention/flash_attn_triton.py index 3fe322311..3a6f9a730 100644 --- a/server/text_generation_server/layers/attention/flash_attn_triton.py +++ b/server/text_generation_server/layers/attention/flash_attn_triton.py @@ -747,11 +747,8 @@ class _attention(torch.autograd.Function): padded_d_model = 1 << (head_size - 1).bit_length() padded_d_model = max(padded_d_model, 16) - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) + def grid(META): + return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch encoded_softmax = None diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 45a0a03ec..e0956b26c 100644 --- 
a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -10,13 +10,14 @@ def attention( q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, causal=True, ): + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return ipex.llm.functional.varlen_attention( q, @@ -49,7 +50,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -59,6 +59,7 @@ def paged_attention( seqlen: Seqlen, max_s: int, ): + out = torch.empty_like(query) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 99c490d5f..69e641629 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -3,6 +3,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils.log import log_master from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -14,7 +15,6 @@ ENGINE = "triton" if use_triton else "ck" try: from vllm._C import cache_ops - from vllm._C import ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" @@ -39,7 +39,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -72,6 +71,8 @@ def paged_attention( max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE input_lengths = input_lengths.input_lengths + out = torch.empty_like(query) + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -136,7 +137,10 @@ if ENGINE != "triton": try: import flash_attn_2_cuda - logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + log_master( + logger.info, + "ROCm: using Flash Attention 2 Composable Kernel implementation.", + ) except ImportError as e: if major >= 8: architecture_suffix = f"-{SYSTEM}" @@ -171,7 +175,6 @@ if ENGINE == "ck": q, k, v, - out, cu_seqlens, max_s, softmax_scale, @@ -181,6 +184,8 @@ if ENGINE == "ck": if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return flash_attn_2_cuda.varlen_fwd( q, @@ -206,13 +211,14 @@ elif ENGINE == "triton": q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, causal=True, ): + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. 
output, _ = triton_attention( q, diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py index c859db1be..391371a55 100644 --- a/server/text_generation_server/layers/awq/quantize/qmodule.py +++ b/server/text_generation_server/layers/awq/quantize/qmodule.py @@ -1,6 +1,5 @@ # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py -import math from typing import Optional import torch import torch.nn as nn diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index ca39919ce..791d9b6d8 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -1,15 +1,17 @@ -import torch -from loguru import logger -from functools import lru_cache +from dataclasses import dataclass + import bitsandbytes as bnb +import torch from bitsandbytes.nn import Int8Params, Params4bit +from text_generation_server.utils.weights import UnquantizedWeight -@lru_cache(1) -def warn_deprecate_bnb(): - logger.warning( - "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" - ) +@dataclass +class BNBWeight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0) class Linear8bitLt(torch.nn.Module): @@ -70,6 +72,22 @@ class Linear8bitLt(torch.nn.Module): return out +@dataclass +class BNBFP4Weight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="fp4") + + +@dataclass +class BNBNF4Weight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="nf4") + + class Linear4bit(torch.nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index fd22b5c67..b1e5235a0 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -1,5 +1,23 @@ +from dataclasses import dataclass + import torch from EETQ import quant_weights, w8_a16_gemm +from text_generation_server.utils.weights import UnquantizedWeight + + +@dataclass +class EETQWeight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + try: + from text_generation_server.layers.eetq import EETQLinear + + return EETQLinear(self.weight, bias) + except ImportError: + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) class EETQLinear(torch.nn.Module): diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index 55cba1cce..a6e07f453 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,12 +1,12 @@ -import torch -from typing import List, Union from dataclasses import dataclass +from typing import List, Union -from text_generation_server.utils.weights import WeightsLoader, Weights +import torch +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader @dataclass -class Exl2Weight: +class Exl2Weight(Weight): """ Exllama2 exl2 quantized weights. 
""" @@ -25,10 +25,39 @@ class Exl2Weight: def device(self) -> torch.device: return self.q_weight.device + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.gptq import ExllamaQuantLinear + + return ExllamaQuantLinear(self, bias) + class Exl2WeightsLoader(WeightsLoader): """Loader for exl2-quantized weights.""" + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + def get_weights_col_packed( self, weights: Weights, @@ -38,46 +67,12 @@ class Exl2WeightsLoader(WeightsLoader): raise RuntimeError("Column-packed weights are not supported for exl") def get_weights_col(self, weights: Weights, prefix: str): - try: - q_weight = weights.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - "Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = weights.get_tensor(f"{prefix}.q_scale") - q_invperm = weights.get_tensor(f"{prefix}.q_invperm") - q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") - q_groups = weights.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): raise ValueError("get_multi_weights_col is not supported for exl2") def get_weights_row(self, weights: Weights, prefix: str): - try: - q_weight = weights.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - "Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = weights.get_tensor(f"{prefix}.q_scale") - q_invperm = weights.get_tensor(f"{prefix}.q_invperm") - q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") - q_groups = weights.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) + # Sharding is not yet supported, so we return the weights as-is. 
+ return self.get_weights(weights, prefix) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index dd61d0819..59b08b55b 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,12 +1,69 @@ import torch +from dataclasses import dataclass +from typing import Optional, Union, List +from loguru import logger + +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import ( + Weight, + WeightsLoader, + UnquantizedWeight, + Weights, +) +from text_generation_server.utils.log import log_master, log_once +import importlib.util + + +FBGEMM_MM_AVAILABLE = False +FBGEMM_DYN_AVAILABLE = False + + +def is_fbgemm_gpu_available(): + try: + return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None + except ModuleNotFoundError: + return False + + +if is_fbgemm_gpu_available(): + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + FBGEMM_MM_AVAILABLE = major == 9 + FBGEMM_DYN_AVAILABLE = major >= 8 +else: + log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") + + +def get_fp8_linear() -> torch.nn.Module: + """ + Return an FP8 linear `Module` that is compatible with the current system. + """ + + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + if major == 8: + from text_generation_server.layers.marlin import GPTQMarlinFP8Linear + + return GPTQMarlinFP8Linear + + # On other systems let Torch decide if the hardware supports FP8. + return Fp8Linear + + +def fp8_quantize( + weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False +): + if FBGEMM_DYN_AVAILABLE and not scalar: + qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( + weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype + ) + return qweight, scale -def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - device = weight.device # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12) + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) @@ -18,20 +75,178 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): return qweight, scale +class HybridFP8UnquantLoader(WeightsLoader): + """Weight loader that loads FP8 and unquantized Torch tensors.""" + + def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + self.activation_scale_ub = activation_scale_ub + self.to_fp8 = to_fp8 + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", dim=0, 
block_sizes=block_sizes, to_dtype=False + ).reshape(-1) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = [ + weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=0).reshape(-1) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + +@dataclass +class Fp8Weight(Weight): + weight: torch.Tensor + dtype: torch.dtype + weight_scale: Optional[torch.Tensor] = None + activation_scale_ub: Optional[float] = None + + def get_linear(self, bias: torch.Tensor): + if self.weight_scale is None: + return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear().from_fp8( + self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + ) + + class Fp8Linear(torch.nn.Module): def __init__( self, - weight, + qweight, + scale, + scale_upper_bound, bias, + dtype, ) -> None: super().__init__() - self.dtype = weight.dtype - self.qweight, self.scale = fp8_quantize(weight) + if FBGEMM_MM_AVAILABLE: + log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + + self.dtype = dtype + self.qweight = qweight + self.scale = scale + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) self.bias = bias if bias is not None else None + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) + return cls( + qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + ) + + @classmethod + def from_fp8(cls, weight, scale, input_scale, bias, dtype): + return cls( + qweight=weight, + scale=scale, + scale_upper_bound=input_scale, + bias=bias, + dtype=dtype, + ) + def forward(self, input: torch.Tensor) -> torch.Tensor: - qinput, scale = fp8_quantize(input) + if FBGEMM_MM_AVAILABLE: + qinput, scale = fp8_quantize( + input, scale_upper_bound=self.scale_upper_bound + ) + + y = torch.ops.fbgemm.f8f8bf16_rowwise( + qinput, + self.qweight, + scale, + self.scale, + use_fast_accum=True, + bias=self.bias, + ) + return y.to(self.dtype) + + qinput, scale = fp8_quantize(input, scalar=True) output, _ = torch._scaled_mm( qinput, self.qweight.t(), diff --git 
a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index efcb3118f..f6616d3e9 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,34 +1,12 @@ -from dataclasses import dataclass -from loguru import logger import os +from dataclasses import dataclass from typing import List, Optional, Union -from safetensors import SafetensorError -from text_generation_server.utils.weights import Weights, WeightsLoader + import torch -from text_generation_server.utils.import_utils import ( - SYSTEM, -) +from loguru import logger +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once - - -@dataclass -class GPTQWeight: - qweight: torch.Tensor - qzeros: torch.Tensor - scales: torch.Tensor - g_idx: Optional[torch.Tensor] - bits: int - groupsize: int - use_exllama: bool - - def __post_init__(self): - if self.scales.dtype == torch.float: - self.scales = self.scales.half() - - @property - def device(self) -> torch.device: - return self.qweight.device - +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader try: major, _minor = torch.cuda.get_device_capability() @@ -44,17 +22,13 @@ elif CAN_EXLLAMA: try: if V2: from text_generation_server.layers.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, + QuantLinear as ExllamaQuantLinear, # noqa: F401 ) HAS_EXLLAMA = "2" else: from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, + Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 ) HAS_EXLLAMA = "1" @@ -62,7 +36,69 @@ elif CAN_EXLLAMA: except ImportError: pass -from text_generation_server.layers.gptq.quant_linear import QuantLinear + +@dataclass +class GPTQWeight(Weight): + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: Optional[torch.Tensor] + bits: int + groupsize: int + use_awq_kernel: bool + use_exllama: bool + + def __post_init__(self): + if self.scales.dtype == torch.float: + self.scales = self.scales.half() + + @property + def device(self) -> torch.device: + return self.qweight.device + + def get_linear(self, bias: torch.Tensor): + if self.use_awq_kernel: + if SYSTEM == "rocm": + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." + ) + try: + from text_generation_server.layers.awq.quantize.qmodule import WQLinear + + return WQLinear( + w_bit=self.bits, + group_size=self.groupsize, + qweight=self.qweight, + qzeros=self.qzeros, + scales=self.scales, + bias=bias, + ) + except ImportError: + raise NotImplementedError( + "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" + ) + elif self.use_exllama: + try: + from text_generation_server.layers.gptq import ExllamaQuantLinear + except ImportError: + raise NotImplementedError( + "Exllama gptq kernels are not installed. 
Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + return ExllamaQuantLinear(self, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + return QuantLinear( + self.qweight, + self.qzeros, + self.scales, + self.g_idx, + bias, + self.bits, + self.groupsize, + ) class GPTQWeightsLoader(WeightsLoader): @@ -87,17 +123,88 @@ class GPTQWeightsLoader(WeightsLoader): self.quantize = quantize self.sym = sym + def get_weights(self, weights: Weights, prefix: str): + self._get_gptq_params(weights) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
+ ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + def get_weights_col_packed( self, weights: Weights, prefix: str, block_sizes: Union[int, List[int]], ): - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - try: qweight = weights.get_packed_sharded( f"{prefix}.qweight", dim=1, block_sizes=block_sizes @@ -112,24 +219,6 @@ class GPTQWeightsLoader(WeightsLoader): scales = scales.to(dtype=weights.dtype) self._get_gptq_params(weights) - if can_use_gptq_marlin( - bits=self.bits, - groupsize=self.groupsize, - quant_method=self.quant_method, - quantize=self.quantize, - sym=self.sym, - ): - g_idx = weights.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - sym=self.sym, - sharded_infeatures=False, - ) qzeros = weights.get_packed_sharded( f"{prefix}.qzeros", dim=1, block_sizes=block_sizes @@ -162,15 +251,11 @@ class GPTQWeightsLoader(WeightsLoader): g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=False, ) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - try: qweight = torch.cat( [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -185,28 +270,6 @@ class GPTQWeightsLoader(WeightsLoader): ) self._get_gptq_params(weights) - if can_use_gptq_marlin( - bits=self.bits, - groupsize=self.groupsize, - quant_method=self.quant_method, - quantize=self.quantize, - sym=self.sym, - ): - w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - sym=self.sym, - sharded_infeatures=False, - ) qzeros = torch.cat( [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 @@ -255,49 +318,12 @@ class GPTQWeightsLoader(WeightsLoader): g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=use_exllama, ) def get_weights_row(self, weights: Weights, prefix: str): - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - self._get_gptq_params(weights) - if can_use_gptq_marlin( - bits=self.bits, - groupsize=self.groupsize, - quant_method=self.quant_method, - quantize=self.quantize, - sym=self.sym, - ): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) - if self.desc_act or self.groupsize == -1: - scales = weights.get_tensor(f"{prefix}.scales") - else: - 
scales = weights.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = weights.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=self.bits, - desc_act=self.desc_act, - groupsize=self.groupsize, - sym=self.sym, - sharded_infeatures=sharded_in_features, - ) use_exllama = True if self.bits != 4: @@ -336,8 +362,8 @@ class GPTQWeightsLoader(WeightsLoader): use_exllama = False from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, CAN_EXLLAMA, + HAS_EXLLAMA, GPTQWeight, ) @@ -389,15 +415,20 @@ class GPTQWeightsLoader(WeightsLoader): g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=use_exllama, ) def _get_gptq_params(self, weights: Weights): - try: + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False - self.sym = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) self.quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - pass diff --git a/server/text_generation_server/layers/gptq/custom_autotune.py b/server/text_generation_server/layers/gptq/custom_autotune.py index 1eb40f1ed..0388ef20b 100644 --- a/server/text_generation_server/layers/gptq/custom_autotune.py +++ b/server/text_generation_server/layers/gptq/custom_autotune.py @@ -91,7 +91,7 @@ class Autotuner(triton.KernelInterface): kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40 ) except triton.OutOfResources: - return (float("inf"), float("inf"), float("inf")) + return [float("inf"), float("inf"), float("inf")] def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 4d45822be..dc3b832f9 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -9,11 +9,12 @@ from loguru import logger from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.gptq import GPTQWeight +from text_generation_server.utils.log import log_master try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error("exllamav2_kernels not installed.") + log_master(logger.warning, "exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/quant_linear.py index f60758b61..736c357b0 100644 --- a/server/text_generation_server/layers/gptq/quant_linear.py +++ b/server/text_generation_server/layers/gptq/quant_linear.py @@ -206,10 +206,13 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): output = torch.empty( (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) + + def grid(META): + return ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + 
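At launch time Triton calls this grid function with the meta-parameters chosen by the autotuner and expects a tuple giving the number of program instances. A pure-Python sketch of the arithmetic with assumed sizes (no Triton needed to run it):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, same as triton.cdiv
    return (a + b - 1) // b

m, n = 33, 4096                                   # rows of `input`, columns of `qweight` (assumed)
meta = {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64}   # example autotuner choice
grid_size = (cdiv(m, meta["BLOCK_SIZE_M"]) * cdiv(n, meta["BLOCK_SIZE_N"]),)
assert grid_size == (192,)                        # 3 * 64 program instances for this launch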
matmul_248_kernel[grid]( input, qweight, diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index c65d5e78d..b0086ea08 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -15,8 +15,9 @@ from text_generation_server.utils.hub import weight_files from text_generation_server.layers.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional +from text_generation_server.layers.gptq.utils import torch_snr_error -from text_generation_server.utils.weights import DefaultWeightsLoader +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight DEV = torch.device("cuda:0") @@ -372,7 +373,7 @@ def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -404,7 +405,7 @@ def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -448,7 +449,7 @@ def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -504,7 +505,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -546,7 +547,7 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code): tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=False, trust_remote_code=trust_remote_code ) - except: + except Exception: tokenizer = AutoTokenizer.from_pretrained( model_id, use_fast=True, trust_remote_code=trust_remote_code ) @@ -700,6 +701,8 @@ def sequential( pass def add_batch(name): + nonlocal gptq + def tmp(_, inp, out): gptq[name].add_batch(inp[0].data, out.data) @@ -871,6 +874,7 @@ def quantize( upload_to_model_id: Optional[str], percdamp: float, act_order: bool, + sym: bool, ): print("loading model") config = AutoConfig.from_pretrained( @@ -893,7 +897,7 @@ def quantize( dtype=torch.float16, process_group=process_group, aliases={"embed_tokens.weight": ["lm_head.weight"]}, - weights_loader=DefaultWeightsLoader(), + weights_loader=DefaultWeightsLoader(UnquantizedWeight), ) hooks = [] for name, module in model.named_modules(): @@ -946,6 +950,7 @@ def quantize( percdamp=percdamp, act_order=act_order, hooks=hooks, + sym=sym, ) print(time.time() - tick) @@ -955,8 +960,6 @@ def quantize( state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} - state_dict["gptq_bits"] = torch.LongTensor([bits]) - state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) max_shard_size = "10GB" shards, index = shard_checkpoint( @@ -988,6 +991,15 @@ def quantize( f"index located at 
{save_index_file}." ) config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + config.quantization_config = { + "bits": bits, + "group_size": groupsize, + "damp_percent": percdamp, + "desc_act": act_order, + "static_groups": False, + "sym": sym, + "quant_method": "gptq", + } config.save_pretrained(output_dir) logger.info("Saved config") logger.info("Saving tokenizer") diff --git a/server/text_generation_server/layers/gptq/utils.py b/server/text_generation_server/layers/gptq/utils.py new file mode 100644 index 000000000..cbc0f391f --- /dev/null +++ b/server/text_generation_server/layers/gptq/utils.py @@ -0,0 +1,56 @@ +import torch + + +# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py +def torch_snr_error( + y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean" +) -> torch.Tensor: + """ + Compute SNR between y_pred(tensor) and y_real(tensor) + + SNR can be calcualted as following equation: + + SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 + + if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. + + SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) + + Args: + y_pred (torch.Tensor): _description_ + y_real (torch.Tensor): _description_ + reduction (str, optional): _description_. Defaults to 'mean'. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + torch.Tensor: _description_ + """ + if y_pred.shape != y_real.shape: + raise ValueError( + f"Can not compute snr loss for tensors with different shape. " + f"({y_pred.shape} and {y_real.shape})" + ) + reduction = str(reduction).lower() + + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(0) + y_real = y_real.unsqueeze(0) + + y_pred = y_pred.flatten(start_dim=1) + y_real = y_real.flatten(start_dim=1) + + noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) + signal_power = torch.pow(y_real, 2).sum(dim=-1) + snr = (noise_power) / (signal_power + 1e-7) + + if reduction == "mean": + return torch.mean(snr) + elif reduction == "sum": + return torch.sum(snr) + elif reduction == "none": + return snr + else: + raise ValueError("Unsupported reduction method.") diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index e94e5465c..12d7f83aa 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,7 +1,6 @@ -from typing import Optional import torch -from torch.nn import functional as F from text_generation_server.utils.import_utils import SYSTEM +from torch.nn import functional as F if SYSTEM == "rocm": try: @@ -90,167 +89,14 @@ class FastLinearROCm(torch.nn.Module): return F.linear(inp, self.weight, self.bias) -def get_linear(weight, bias, quantize): - if quantize is None: +def get_linear(weight, bias): + # Weights that are loaded through methods that are not + # quantization-aware are still bare tensors. We may want + # to change this in the future. 
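The quantizer-specific branches that used to live here have moved onto the weight objects themselves: each `Weight` subclass now builds its own linear module via `get_linear(bias)`, and only bare tensors take the fast unquantized path. A schematic sketch of the dispatch (toy classes, not the real `Weight` hierarchy):

import torch

class ToyQuantWeight:
    """Stands in for GPTQWeight, Fp8Weight, MarlinWeight, ...; purely illustrative."""

    def get_linear(self, bias):
        # Each quantizer-specific Weight builds the module that matches its packing/kernel.
        return ("quantized-linear", bias)

def toy_get_linear(weight, bias):
    if isinstance(weight, torch.Tensor):
        # Unquantized fallback: wrap the bare tensor directly.
        return torch.nn.Linear(weight.shape[1], weight.shape[0])
    return weight.get_linear(bias)

assert isinstance(toy_get_linear(torch.randn(4, 8), None), torch.nn.Linear)
assert toy_get_linear(ToyQuantWeight(), None) == ("quantized-linear", None)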
+ if isinstance(weight, torch.Tensor): if SYSTEM == "rocm": - linear = FastLinearROCm(weight, bias) + return FastLinearROCm(weight, bias) else: - linear = FastLinear(weight, bias) - elif quantize == "eetq": - try: - from text_generation_server.layers.eetq import EETQLinear + return FastLinear(weight, bias) - linear = EETQLinear(weight, bias) - except ImportError: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Linear - - linear = Fp8Linear(weight, bias) - elif quantize == "bitsandbytes": - try: - from text_generation_server.layers.bnb import ( - warn_deprecate_bnb, - Linear8bitLt, - ) - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - warn_deprecate_bnb() - linear = Linear8bitLt( - weight, - bias, - has_fp16_weights=False, - threshold=6.0, - ) - if bias is not None: - linear.bias = nn.Parameter(bias) - elif quantize == "bitsandbytes-fp4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="fp4", - ) - elif quantize == "bitsandbytes-nf4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="nf4", - ) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - if not isinstance(weight, Exl2Weight): - raise NotImplementedError( - f"The passed weight is not `exl2` compatible, loader needs to be updated." - ) - - from text_generation_server.layers.gptq import ExllamaQuantLinear - - linear = ExllamaQuantLinear(weight, bias) - - elif quantize == "gptq": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlinLinear, - GPTQMarlinWeight, - ) - - if isinstance(weight, GPTQMarlinWeight): - linear = GPTQMarlinLinear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, GPTQWeight): - if weight.use_exllama: - try: - from text_generation_server.layers.gptq import ( - ExllamaQuantLinear, - ) - except ImportError: - raise NotImplementedError( - f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - linear = ExllamaQuantLinear(weight, bias) - else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - - linear = QuantLinear( - weight.qweight, - weight.qzeros, - weight.scales, - weight.g_idx, - bias, - weight.bits, - weight.groupsize, - ) - else: - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." - ) - - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight - - if not isinstance(weight, GPTQWeight): - raise NotImplementedError( - f"The passed weight is not `awq` compatible, loader needs to be updated." - ) - if SYSTEM == "rocm": - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." 
- ) - try: - from text_generation_server.layers.awq.quantize.qmodule import WQLinear - - linear = WQLinear( - w_bit=weight.bits, - group_size=weight.groupsize, - qweight=weight.qweight, - qzeros=weight.qzeros, - scales=weight.scales, - bias=bias, - ) - except ImportError: - raise NotImplementedError( - "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Linear, - GPTQMarlin24Weight, - MarlinLinear, - MarlinWeight, - ) - - if isinstance(weight, GPTQMarlin24Weight): - linear = GPTQMarlin24Linear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, MarlinWeight): - linear = MarlinLinear(weight=weight, bias=bias) - else: - raise NotImplementedError( - f"The passed weight is not `marlin` compatible, loader needs to be updated." - ) - else: - raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") - return linear + return weight.get_linear(bias) diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index 0bb6db41a..a4537b55b 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -1,12 +1,8 @@ -import math -import os -from typing import TYPE_CHECKING, Optional, Tuple, List +from typing import TYPE_CHECKING, Optional, List import torch import torch.distributed -from accelerate import init_empty_weights from torch import nn -from torch.nn import functional as F from torch.distributed import ProcessGroup from text_generation_server.utils.sgmv import ( @@ -43,10 +39,7 @@ class LoraLinear(nn.Module): ) -> torch.Tensor: if adapter_data is None: return result - data = adapter_data.data.get(layer_type) - data: Optional["BatchLoraWeights"] = ( - data.get("lora") if data is not None else None - ) + data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) if has_sgmv() and data is not None and data.can_vectorize(self.process_group): # In tensor-parallel configurations, each GPU processes a specific segment of the output. 
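The monolithic marlin.py that follows is split into a package (fp8, gptq, marlin, util) whose __init__ re-exports the public entry points, so callers keep a single import path. A sketch of the resulting import surface, assuming the text_generation_server package is importable in your environment:

from text_generation_server.layers.marlin import (
    GPTQMarlinFP8Linear,      # FP8 weights run through the Marlin kernel
    GPTQMarlinWeightsLoader,  # GPTQ/AWQ checkpoints repacked for Marlin
    MarlinWeightsLoader,      # native Marlin / Marlin 2:4 checkpoints
    can_use_gptq_marlin,      # capability check: CUDA, SM >= 8.0, supported bits and group sizes
)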
diff --git a/server/text_generation_server/layers/marlin/__init__.py b/server/text_generation_server/layers/marlin/__init__.py new file mode 100644 index 000000000..3ff3ed58f --- /dev/null +++ b/server/text_generation_server/layers/marlin/__init__.py @@ -0,0 +1,15 @@ +from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear +from text_generation_server.layers.marlin.gptq import ( + GPTQMarlinWeightsLoader, + can_use_gptq_marlin, + repack_gptq_for_marlin, +) +from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader + +__all__ = [ + "GPTQMarlinFP8Linear", + "GPTQMarlinWeightsLoader", + "MarlinWeightsLoader", + "can_use_gptq_marlin", + "repack_gptq_for_marlin", +] diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py new file mode 100644 index 000000000..fe55a58a3 --- /dev/null +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -0,0 +1,140 @@ +from typing import Optional + +import torch +import torch.nn as nn +from loguru import logger +from text_generation_server.layers.fp8 import fp8_quantize +from text_generation_server.layers.marlin.gptq import _check_valid_shape +from text_generation_server.layers.marlin.util import ( + _check_marlin_kernels, + permute_scales, +) +from text_generation_server.utils.log import log_once + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + +MARLIN_TILE_SIZE = 16 + + +class GPTQMarlinFP8Linear(nn.Module): + """ + FP8 GPTQ-Marlin linear layer. + """ + + def __init__( + self, + qweight: torch.Tensor, + scales: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + + scales = scales.unsqueeze(0) + if scales.shape[1] == 1: + out_features, in_features = qweight.shape + scales = scales.repeat(1, out_features) + qweight, scales = repack_fp8_for_marlin(qweight, scales) + + in_features = qweight.shape[0] * MARLIN_TILE_SIZE + out_features = scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.qweight = qweight + self.scales = scales + self.bias = bias if bias is not None else None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=qweight.device + ) + + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scales = fp8_quantize(weight) + return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) + + @classmethod + def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + return cls(qweight=weight, scales=scale.to(dtype), bias=bias) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.fp8_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.workspace, + 8, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements). 
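Four consecutive FP8 values along the leading dimension are reinterpreted as bytes and OR-ed into one int32, with byte i shifted by 8*i bits. A small sketch of that packing with assumed sample values (requires a PyTorch build with float8 support):

import torch

vals = torch.tensor([1.0, 0.5, 2.0, 0.25]).to(torch.float8_e4m3fn)  # assumed values
as_bytes = vals.view(torch.uint8).to(torch.int32)                   # one byte per FP8 value
packed = as_bytes[0] | (as_bytes[1] << 8) | (as_bytes[2] << 16) | (as_bytes[3] << 24)
# `packed` is a single int32 holding all four quantized values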
+ """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + + if fp8_tensor.shape[0] % 4 != 0: + raise ValueError( + f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" + ) + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = torch.zeros( + fp8_tensor.shape[0] // 4, + fp8_tensor.shape[1], + dtype=torch.int32, + device=fp8_tensor.device, + ) + + for i in range(4): + packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) + + return packed + + +def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): + """ + Repack FP8 tensor for GPTQ-Marlin. + """ + + out_features, in_features = weight.shape + + # Torch linear layers weights with shape [out_features, in_features], + # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], + # so transpose before packing. + qweight = pack_fp8_as_int32(weight.t()) + + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, 8 + ) + + scales = permute_scales(scales) + + return repacked, scales diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py new file mode 100644 index 000000000..c7663b60b --- /dev/null +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -0,0 +1,465 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy +import torch +import torch.nn as nn +from loguru import logger +from text_generation_server.layers.marlin.util import ( + _check_marlin_kernels, + marlin_zero_points, + permute_scales, + unpack_cols, +) +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + + +GPTQ_MARLIN_BITS = [4, 8] +GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] +MARLIN_TILE_SIZE = 16 + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: + return ( + SYSTEM == "cuda" + and marlin_kernels is not None + and has_sm_8_0 + and quantize in {"awq", "gptq"} + and quant_method in {"awq", "gptq"} + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + # We only suppord asymmetric quantization for AWQ. + and (sym or quant_method == "awq") + ) + + +class GPTQMarlinWeightsLoader(WeightsLoader): + """ + Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels. 
+ """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + if not self.sym: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + scales = weights.get_tensor(f"{prefix}.scales") + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + if not self.sym: + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + if not self.sym: + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=False, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + if not self.sym: + if self.desc_act or self.groupsize == 
-1: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + else: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + else: + qzeros = None + + if self.quant_method == "awq": + g_idx = None + else: + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + qzeros=qzeros, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + quant_method=self.quant_method, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + def _get_gptq_params(self, weights: Weights): + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) + self.quant_method = "gptq" + + +@dataclass +class GPTQMarlinWeight(Weight): + """ + Repacked GPTQ Marlin weights. + """ + + qweight: torch.Tensor + qzeros: torch.Tensor + scales: torch.Tensor + g_idx: torch.Tensor + perm: torch.Tensor + bits: int + is_full_k: bool + + def __post_init__(self): + assert self.qweight.dtype == torch.int32 + assert self.scales.dtype == torch.float16 + assert self.g_idx.dtype == torch.int32 + assert self.perm.dtype == torch.int32 + + def get_linear(self, bias: torch.Tensor): + return GPTQMarlinLinear( + weight=self, + bias=bias, + ) + + +def repack_gptq_for_marlin( + *, + qweight: torch.Tensor, + qzeros: Optional[torch.Tensor], + scales: torch.Tensor, + g_idx: Optional[torch.Tensor], + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + sym: bool, + sharded_infeatures: bool, +) -> GPTQMarlinWeight: + """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" + _check_marlin_kernels() + assert marlin_kernels is not None + + if bits not in GPTQ_MARLIN_BITS: + supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) + raise RuntimeError( + f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" + ) + + if groupsize not in GPTQ_MARLIN_GROUP_SIZES: + supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) + raise RuntimeError( + f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" + ) + if not (sym or quant_method == "awq"): + raise RuntimeError( + "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." 
+ ) + + log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.") + + weights_per_int = 32 // bits + in_features = qweight.shape[0] + out_features = qweight.shape[1] + + # AWQ uses column packing, GPTQ uses row packing + if quant_method == "awq": + out_features *= weights_per_int + else: + in_features *= weights_per_int + + if in_features % groupsize != 0: + raise ValueError( + f"Number of input features ({in_features}) not divisible by group size ({groupsize})" + ) + + if g_idx is not None and desc_act and groupsize != -1: + perm = torch.argsort(g_idx).to(torch.int) + g_idx = g_idx[perm] + else: + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) + + if quant_method == "awq": + repacked = marlin_kernels.awq_marlin_repack( + qweight, in_features, out_features, bits + ) + if qzeros is not None: + qzeros = awq_to_marlin_zero_points( + qzeros, + in_features // groupsize, + out_features, + bits, + ) + + else: + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, bits + ) + + if qzeros is None: + qzeros = torch.empty(0, dtype=torch.int, device=qweight.device) + + scales = permute_scales(scales) + + is_full_k = not (desc_act and sharded_infeatures) + + return GPTQMarlinWeight( + qweight=repacked, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + perm=perm, + bits=bits, + is_full_k=is_full_k, + ) + + +class GPTQMarlinLinear(nn.Module): + """ + Linear layer for GPTQ weights that were converted for the GPTQ-Marlin + kernels. + """ + + def __init__( + self, + *, + weight: GPTQMarlinWeight, + bias: Optional[torch.Tensor], + ): + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE + out_features = weight.scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.bits = weight.bits + self.is_full_k = weight.is_full_k + + self.qweight = weight.qweight + self.qzeros = weight.qzeros + self.scales = weight.scales + self.g_idx = weight.g_idx + self.perm = weight.perm + if bias is not None: + self.bias = bias + else: + self.bias = None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.gptq_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.perm, + self.workspace, + self.bits, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + self.is_full_k, + self.qzeros.numel() > 0, + True, + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) 
to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def _check_valid_shape(in_features: int, out_features: int): + if (in_features % 128 != 0 or out_features % 64 != 0) and ( + in_features % 64 != 0 or out_features % 128 != 0 + ): + raise ValueError( + f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." + " The shape elements must be divisible by (128, 64) or (64, 128)." + ) diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin/marlin.py similarity index 57% rename from server/text_generation_server/layers/marlin.py rename to server/text_generation_server/layers/marlin/marlin.py index ecb88e76e..89ebaca62 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin/marlin.py @@ -1,28 +1,16 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union -from text_generation_server.utils.weights import Weights, WeightsLoader import torch import torch.nn as nn - -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.marlin.util import _check_marlin_kernels +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader try: import marlin_kernels except ImportError: marlin_kernels = None -try: - major, _minor = torch.cuda.get_device_capability() - has_sm_8_0 = major >= 8 -except Exception: - has_sm_8_0 = False - - -GPTQ_MARLIN_BITS = [4, 8] -GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] -MARLIN_TILE_SIZE = 16 - class MarlinWeightsLoader(WeightsLoader): """Loader for Marlin-quantized weights.""" @@ -31,6 +19,35 @@ class MarlinWeightsLoader(WeightsLoader): self.bits = bits self.is_marlin_24 = is_marlin_24 + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_tensor(f"{prefix}.B_24") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_tensor(f"{prefix}.B_meta") + s = weights.get_tensor(f"{prefix}.s") + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_tensor(f"{prefix}.B") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." 
+ ) + + s = weights.get_tensor(f"{prefix}.s") + weight = MarlinWeight(B=B, s=s) + + return weight + def get_weights_col_packed( self, weights: Weights, @@ -61,15 +78,14 @@ class MarlinWeightsLoader(WeightsLoader): return weight def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: + if self.is_marlin_24: try: B = torch.cat( [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 ) except RuntimeError: raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" + "Cannot load `marlin` weight, make sure the model is already quantized" ) B_meta = torch.cat( @@ -88,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader): ) except RuntimeError: raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" + "Cannot load `marlin` weight, make sure the model is already quantized" ) s = torch.cat( [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 @@ -99,8 +115,7 @@ class MarlinWeightsLoader(WeightsLoader): return weight def get_weights_row(self, weights: Weights, prefix: str): - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: + if self.is_marlin_24: try: B = weights.get_sharded(f"{prefix}.B_24", dim=0) except RuntimeError: @@ -138,206 +153,73 @@ class MarlinWeightsLoader(WeightsLoader): return weight -def can_use_gptq_marlin( - *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool -) -> bool: - return ( - SYSTEM == "cuda" - and marlin_kernels is not None - and has_sm_8_0 - and quantize == "gptq" - and quant_method == "gptq" - and bits in GPTQ_MARLIN_BITS - and groupsize in GPTQ_MARLIN_GROUP_SIZES - and sym - ) - - -def _check_marlin_kernels(): - if not (SYSTEM == "cuda" and has_sm_8_0): - raise NotImplementedError( - "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." - ) - - if marlin_kernels is None: - raise NotImplementedError( - "marlin is not installed, install it with: pip install server/marlin" - ) - - -def _check_valid_shape(in_features: int, out_features: int): - if (in_features % 128 != 0 or out_features % 64 != 0) and ( - in_features % 64 != 0 or out_features % 128 != 0 - ): - raise ValueError( - f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})." - " The shape elements must be divisible by (128, 64) or (64, 128)." - ) - - -# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 -def _get_perms() -> Tuple[List[int], List[int]]: - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -_scale_perm, _scale_perm_single = _get_perms() - - -def permute_scales(scales: torch.Tensor): - out_features = scales.shape[1] - if scales.shape[0] == 1: - scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] - else: - scales = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm] - return scales.reshape((-1, out_features)).contiguous() - - @dataclass -class GPTQMarlinWeight: +class MarlinWeight(Weight): """ - Repacked GPTQ Marlin weights. + Marlin weights. + + Attributes: + B (torch.Tensor): int4-quantized weights packed into int32. 
+ s (torch.Tensor): bfloat16/float16 scales. """ - qweight: torch.Tensor - scales: torch.Tensor - g_idx: torch.Tensor - perm: torch.Tensor - bits: int - is_full_k: bool + B: torch.Tensor + s: torch.Tensor def __post_init__(self): - assert self.qweight.dtype == torch.int32 - assert self.scales.dtype == torch.float16 - assert self.g_idx.dtype == torch.int32 - assert self.perm.dtype == torch.int32 + assert self.B.dtype == torch.int32 + assert self.s.dtype in [torch.float16, torch.bfloat16] + + def get_linear(self, bias: torch.Tensor): + return MarlinLinear(weight=self, bias=bias) -def repack_gptq_for_marlin( - *, - qweight: torch.Tensor, - scales: torch.Tensor, - g_idx: torch.Tensor, - bits: int, - desc_act: bool, - groupsize: int, - sym: bool, - sharded_infeatures: bool, -) -> GPTQMarlinWeight: - """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.""" - _check_marlin_kernels() - assert marlin_kernels is not None - - if bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) - raise RuntimeError( - f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}" - ) - - if groupsize not in GPTQ_MARLIN_GROUP_SIZES: - supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES) - raise RuntimeError( - f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}" - ) - if not sym: - raise RuntimeError( - "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported." - ) - - weights_per_int = 32 // bits - in_features = qweight.shape[0] * weights_per_int - out_features = qweight.shape[1] - - if in_features % groupsize != 0: - raise ValueError( - f"Number of input features ({in_features}) not divisible by group size ({groupsize})" - ) - - if desc_act and groupsize != -1: - perm = torch.argsort(g_idx).to(torch.int) - g_idx = g_idx[perm] - else: - perm = torch.empty(0, dtype=torch.int, device=qweight.device) - g_idx = torch.empty(0, dtype=torch.int, device=qweight.device) - - repacked = marlin_kernels.gptq_marlin_repack( - qweight, perm, in_features, out_features, bits - ) - - scales = permute_scales(scales) - - is_full_k = not (desc_act and sharded_infeatures) - - return GPTQMarlinWeight( - qweight=repacked, - scales=scales, - g_idx=g_idx, - perm=perm, - bits=bits, - is_full_k=is_full_k, - ) - - -class GPTQMarlinLinear(nn.Module): - """ - Linear layer for GPTQ weights that were converted for the GPTQ-Marlin - kernels. 
- """ - - def __init__( - self, - *, - weight: GPTQMarlinWeight, - bias: Optional[torch.Tensor], - ): +class MarlinLinear(nn.Module): + def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): super().__init__() _check_marlin_kernels() assert marlin_kernels is not None - in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE - out_features = weight.scales.shape[1] - _check_valid_shape(in_features=in_features, out_features=out_features) + in_features = weight.B.shape[0] * MARLIN_TILE_SIZE + out_features = weight.s.shape[1] + assert ( + in_features % 128 == 0 + ), f"Number of input features ({in_features}) not divisable by 128" + assert ( + out_features % 256 == 0 + ), f"Number of output features ({out_features}) not divisable by 256" - self.bits = weight.bits - self.is_full_k = weight.is_full_k + groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] + assert groupsize in { + -1, + 128, + }, f"Group size must be -1 or 128, was {groupsize}" - self.qweight = weight.qweight - self.scales = weight.scales - self.g_idx = weight.g_idx - self.perm = weight.perm + self.B = weight.B + self.s = weight.s if bias is not None: self.bias = bias else: self.bias = None self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device + out_features // 64 * 16, dtype=torch.int, device=weight.B.device ) def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None - A_flat = A.view(-1, A.shape[-1]) - C = marlin_kernels.gptq_marlin_gemm( - A_flat, - self.qweight, - self.scales, - self.g_idx, - self.perm, + C = marlin_kernels.marlin_gemm( + A.view(-1, A.shape[-1]), + self.B, + self.s, self.workspace, - self.bits, - A_flat.shape[0], - self.scales.shape[1], - A_flat.shape[1], - self.is_full_k, + A.shape[0], + self.s.shape[1], + A.shape[1], ) - C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) if self.bias is not None: C += self.bias @@ -350,6 +232,7 @@ GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MAX_PARALLEL = 64 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_TILE_SIZE = 16 @dataclass @@ -374,6 +257,12 @@ class GPTQMarlin24Weight: assert self.B_meta.dtype == torch.int16 assert self.s.dtype == torch.float16 + def get_linear(self, bias: torch.Tensor): + return GPTQMarlin24Linear( + weight=self, + bias=bias, + ) + class GPTQMarlin24Linear(nn.Module): def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): @@ -382,8 +271,10 @@ class GPTQMarlin24Linear(nn.Module): _check_marlin_kernels() assert marlin_kernels is not None - if weight.bits not in GPTQ_MARLIN_BITS: - supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS) + if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + supported_bits = ", ".join( + str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + ) raise RuntimeError( f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}" ) @@ -453,74 +344,3 @@ class GPTQMarlin24Linear(nn.Module): C += self.bias return C - - -@dataclass -class MarlinWeight: - """ - Marlin weights. - - Attributes: - B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): float16 scales. 
- """ - - B: torch.Tensor - s: torch.Tensor - - def __post_init__(self): - assert self.B.dtype == torch.int32 - assert self.s.dtype == torch.float16 - - -class MarlinLinear(nn.Module): - def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): - super().__init__() - - _check_marlin_kernels() - assert marlin_kernels is not None - - in_features = weight.B.shape[0] * MARLIN_TILE_SIZE - out_features = weight.s.shape[1] - assert ( - in_features % 128 == 0 - ), f"Number of input features ({in_features}) not divisable by 128" - assert ( - out_features % 256 == 0 - ), f"Number of output features ({out_features}) not divisable by 256" - - groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0] - assert groupsize in { - -1, - 128, - }, f"Group size must be -1 or 128, was {groupsize}" - - self.B = weight.B - self.s = weight.s - if bias is not None: - self.bias = bias - else: - self.bias = None - - self.workspace = torch.zeros( - out_features // 64 * 16, dtype=torch.int, device=weight.B.device - ) - - def forward(self, A: torch.Tensor) -> torch.Tensor: - assert marlin_kernels is not None - - C = marlin_kernels.marlin_gemm( - A.view(-1, A.shape[-1]), - self.B, - self.s, - self.workspace, - A.shape[0], - self.s.shape[1], - A.shape[1], - ) - C = C.reshape(A.shape[:-1] + (self.s.shape[1],)) - - if self.bias is not None: - C += self.bias - - return C diff --git a/server/text_generation_server/layers/marlin/util.py b/server/text_generation_server/layers/marlin/util.py new file mode 100644 index 000000000..250d17141 --- /dev/null +++ b/server/text_generation_server/layers/marlin/util.py @@ -0,0 +1,141 @@ +import functools +from typing import List, Tuple + +import numpy +import torch +from text_generation_server.utils.import_utils import SYSTEM + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + +try: + major, _minor = torch.cuda.get_device_capability() + has_sm_8_0 = major >= 8 +except Exception: + has_sm_8_0 = False + + +def _check_marlin_kernels(): + if not (SYSTEM == "cuda" and has_sm_8_0): + raise NotImplementedError( + "Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later." 
+ ) + + if marlin_kernels is None: + raise NotImplementedError( + "marlin is not installed, install it with: pip install server/marlin" + ) + + +# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 +@functools.cache +def get_perms() -> Tuple[List[int], List[int]]: + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def permute_scales(scales: torch.Tensor): + scale_perm, scale_perm_single = get_perms() + out_features = scales.shape[1] + if scales.shape[0] == 1: + scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + else: + scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm] + return scales.reshape((-1, out_features)).contiguous() + + +# Functions below are from vLLM + + +def get_pack_factor(bits: int) -> int: + if 32 % bits != 0: + raise ValueError(f"Cannot {bits} bit values into uint32") + return 32 // bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + scale_perm, _ = get_perms() + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 87a61e82a..fc4a59b94 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ 
-1,11 +1,10 @@ import os +import math import torch from torch import nn - from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "cuda": - from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb elif SYSTEM == "rocm": from vllm._C import ops @@ -84,9 +83,13 @@ class PositionRotaryEmbedding(nn.Module): scaling_factor = None rope_scaling = _get_rope_config(config) if rope_scaling is not None: - if rope_scaling["type"] == "linear": + # `rope_type` is now standard in transformers, but some existing models + # have `type` instead. + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) + + if rope_type == "linear": pass - elif rope_scaling["type"] == "dynamic": + elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( dim=dim, @@ -95,22 +98,39 @@ class PositionRotaryEmbedding(nn.Module): device=inv_freq.device, scaling_factor=scaling_factor, ) - elif rope_scaling["type"] == "yarn": + elif rope_type == "llama3": + inv_freq = apply_llama3_scaling( + inv_freq, + scaling_factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + ) + + return cls(inv_freq, scaling_factor) + + elif rope_type == "yarn": scaling_factor = rope_scaling["factor"] + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ "original_max_position_embeddings" ], - base=10000.0, + base=base, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) - elif rope_scaling["type"] in ["su", "longrope"]: + elif rope_type in ["su", "longrope"]: short_factor = torch.tensor( rope_scaling["short_factor"], dtype=torch.float32, device=device ) @@ -181,6 +201,8 @@ class PositionRotaryEmbedding(nn.Module): scaling_factor=scaling_factor, ) elif rope_scaling["type"] == "yarn": + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ @@ -193,6 +215,8 @@ class PositionRotaryEmbedding(nn.Module): attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) else: raise NotImplementedError( @@ -318,10 +342,6 @@ class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): self._sin_cached = torch.sin(freqs).to(dtype) -# Inverse dim formula to find dim based on number of rotations -import math - - def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 2 * math.log(base) @@ -346,10 +366,10 @@ def linear_ramp_mask(min, max, dim): return ramp_func -def get_mscale(scale=1): +def get_mscale(scale: float = 1.0, mscale: float = 1.0): if scale <= 1: return 1.0 - return 0.1 * math.log(scale) + 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): @@ -365,6 +385,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): attn_factor, beta_fast, beta_slow, + mscale: float, + mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, 
scaling_factor) @@ -375,8 +397,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.mscale_all_dim = mscale_all_dim + self.scaling_factor = scaling_factor self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor + get_mscale(self.scaling_factor, mscale) + / get_mscale(self.scaling_factor, mscale_all_dim) + * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): @@ -387,7 +413,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.device != device or self._cos_cached.dtype != dtype ): - if seqlen > self.max_position_embeddings: + if seqlen > self.max_position_embeddings or True: inv_freq_extrapolation = _create_inv_freq( self.dim, self.base, self.inv_freq.device ) @@ -400,6 +426,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.base, self.max_position_embeddings, ) + inv_freq_mask = ( 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation @@ -409,9 +436,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): ) self.inv_freq = inv_freq - self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) @@ -421,3 +445,33 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): freqs = torch.outer(t, self.inv_freq.to(device=t.device)) self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) + + +def apply_llama3_scaling( + freqs: torch.Tensor, + *, + scaling_factor: int, + low_freq_factor: int, + high_freq_factor: int, + original_max_position_embeddings: int, +): + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + new_freqs = [] + + for freq in freqs: + wavelen = 2 * math.pi / freq + + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scaling_factor) + else: + + assert low_freq_wavelen != high_freq_wavelen + smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) + + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) diff --git a/server/text_generation_server/layers/speculative.py b/server/text_generation_server/layers/speculative.py index 4b977a56a..cf8469b53 100644 --- a/server/text_generation_server/layers/speculative.py +++ b/server/text_generation_server/layers/speculative.py @@ -33,7 +33,7 @@ class SpeculativeHead(torch.nn.Module): except KeyError: try: speculator = MedusaHeadV1.load(config, prefix, weights) - except: + except Exception: speculator = MedusaHeadV2(config, prefix, weights) lm_head = None else: diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 011f105b6..13f12ef1e 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -2,7 +2,6 @@ import torch from torch.nn import functional as F from typing import Iterable, List from 
text_generation_server.layers.linear import get_linear, FastLinear -from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "ipex": @@ -50,7 +49,7 @@ class TensorParallelHead(SuperLayer): # If the piece and LM head embeddings are shared, we have # non-quantized weights... weight = weights.get_tensor(f"{prefix}.weight") - except: + except Exception: # ...otherwise they are quantized. weight = weights.get_weights_col(prefix) should_gather = weights.process_group.size() > 1 @@ -67,17 +66,8 @@ class TensorParallelHead(SuperLayer): weight = weights.get_tensor(f"{prefix}.weight") should_gather = False - # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) - if config.quantize in ["gptq", "awq", "eetq", "marlin"]: - quantize = None - # See above, exl2 LM head can be quantized or not. - elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight): - quantize = None - else: - quantize = config.quantize - return TensorParallelHead( - get_linear(weight, bias=None, quantize=quantize), + get_linear(weight, bias=None), process_group=weights.process_group, should_gather=should_gather, ) @@ -134,7 +124,7 @@ class TensorParallelColumnLinear(SuperLayer): raise NotImplementedError("packed_gate_up only implemented without bias") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -157,7 +147,7 @@ class TensorParallelColumnLinear(SuperLayer): raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -167,7 +157,7 @@ class TensorParallelColumnLinear(SuperLayer): bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -177,7 +167,7 @@ class TensorParallelColumnLinear(SuperLayer): for prefix in prefixes: weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None - linears.append(get_linear(weight, b, config.quantize)) + linears.append(get_linear(weight, b)) linear = LayerConcat(linears) else: weight = weights.get_multi_weights_col(prefixes, dim=dim) @@ -186,7 +176,7 @@ class TensorParallelColumnLinear(SuperLayer): bias = torch.cat(b, dim=dim) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @@ -205,7 +195,7 @@ class TensorParallelRowLinear(SuperLayer): else: bias = None return cls( - get_linear(weight, bias, config.quantize), + get_linear(weight, bias), process_group=weights.process_group, ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ba980195d..3dc24159d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,3 +1,6 @@ +# ruff: noqa: F821 +# the above line disables the `undefined-name` rule for the model type variables + import torch import enum import os @@ -6,7 +9,7 @@ from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi -from typing import Optional, List +from typing import Optional, List, Dict from pathlib import Path from 
text_generation_server.utils.speculate import get_speculate, set_speculate @@ -33,7 +36,18 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( T5ForConditionalGeneration, ) + +from text_generation_server.utils.adapter import ( + AdapterParameters, + build_layer_weight_lookup, + load_and_merge_adapters, + AdapterInfo, +) +from text_generation_server.adapters.lora import LoraWeights + + from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_master # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -47,11 +61,9 @@ torch.set_grad_enabled(False) __all__ = [ "Model", - "BLOOMSharded", "CausalLM", - "GalacticaSharded", "Seq2SeqLM", - "get_model", + "get_model_with_lora_adapters", ] FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." @@ -61,6 +73,10 @@ FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) @@ -121,7 +137,7 @@ try: ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") SUPPORTS_WINDOWING = False FLASH_ATTENTION = False @@ -133,7 +149,7 @@ MAMBA_AVAILABLE = True try: from text_generation_server.models.mamba import Mamba except ImportError as e: - logger.warning(f"Could not import Mamba: {e}") + log_master(logger.warning, f"Could not import Mamba: {e}") MAMBA_AVAILABLE = False if MAMBA_AVAILABLE: @@ -141,6 +157,11 @@ if MAMBA_AVAILABLE: class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } IDEFICS2 = { "type": "idefics2", "name": "Idefics 2", @@ -298,10 +319,34 @@ def get_model( max_input_tokens: int, ) -> Model: global FLASH_ATTENTION + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + model_type = config_dict.get("model_type", None) + + quantization_config = config_dict.get("quantization_config", None) + if quantization_config is not None and quantize is None: + method = quantization_config.get("quant_method", None) + if method in {"gptq", "awq", "exl2"}: + log_master(logger.info, f"Auto selecting quantization method {method}") + quantize = method + elif method == "fbgemm_fp8": + log_master(logger.info, "Auto selecting quantization method fp8") + quantize = "fp8" + else: + log_master(logger.warning, f"Unknown quantization method {method}") + if dtype is None: if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 + elif quantize == "fp8": + from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE + + if FBGEMM_DYN_AVAILABLE: + # fbgemm kernels are fp8xfp8->bf16 + dtype = torch.bfloat16 else: # Keep it as default for now and let # every model resolve their own default dtype. 
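The auto-detection added to `get_model` above boils down to: if the checkpoint ships a `quantization_config` and no `--quantize` flag was passed, adopt the checkpoint's method (mapping `fbgemm_fp8` to `fp8`, warning on anything unknown), then derive the dtype from the chosen quantizer. A minimal standalone sketch of that decision; the helper name and the plain `fbgemm_available` boolean are illustrative stand-ins, not part of the module:

# Sketch only: `resolve_quantize_and_dtype` and `fbgemm_available` are
# illustrative stand-ins for the logic added to get_model above.
from typing import Optional, Tuple

import torch


def resolve_quantize_and_dtype(
    config_dict: dict,
    quantize: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    fbgemm_available: bool = False,
) -> Tuple[Optional[str], Optional[torch.dtype]]:
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            quantize = method
        elif method == "fbgemm_fp8":
            quantize = "fp8"
        # any other method is left unset and only logged as a warning

    if dtype is None:
        if quantize in ["awq", "exl2", "gptq", "marlin"]:
            # These quantizers only work with float16 parameters.
            dtype = torch.float16
        elif quantize == "fp8" and fbgemm_available:
            # fbgemm kernels are fp8 x fp8 -> bf16.
            dtype = torch.bfloat16

    return quantize, dtype


# A GPTQ checkpoint with no explicit --quantize/--dtype flags resolves to:
assert resolve_quantize_and_dtype(
    {"quantization_config": {"quant_method": "gptq"}}
) == ("gptq", torch.float16)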
@@ -318,11 +363,6 @@ def get_model( else: set_speculate(0) - config_dict, _ = PretrainedConfig.get_config_dict( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - model_type = config_dict.get("model_type", None) - speculator = None if "medusa_num_heads" in config_dict: medusa_model_id = model_id @@ -424,7 +464,9 @@ def get_model( speculate = get_speculate() if speculate > 0: - logger.info(f"Using speculation {method} with {speculate} input ids.") + log_master( + logger.info, f"Using speculation {method} with {speculate} input ids." + ) if model_type is None: # TODO: fix how we determine model type for Mamba @@ -435,14 +477,6 @@ def get_model( raise RuntimeError( f"Could not determine model type for {model_id} revision {revision}" ) - quantization_config = config_dict.get("quantization_config", None) - if quantization_config is not None and quantize is None: - method = quantization_config.get("quant_method", None) - if method in {"gptq", "awq", "exl2"}: - logger.info(f"Auto selecting quantization method {method}") - quantize = method - else: - logger.info(f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( @@ -459,7 +493,40 @@ def get_model( f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." ) - if model_type == MAMBA: + if model_type == DEEPSEEK_V2: + if FLASH_ATTENTION: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2") + ) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == MAMBA: return Mamba( model_id, revision, @@ -551,7 +618,7 @@ def get_model( ) except RuntimeError as e: # Lots of legacy models with various weight names. 
- logger.warning(f"Couldn't load flash gpt2 variant: {e}") + log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") return CausalLM.fallback( model_id, revision, @@ -573,6 +640,10 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + GPTNeoXConfig, + ) + return FlashCausalLM( model_id=model_id, model_class=FlashGPTNeoXForCausalLM, @@ -582,6 +653,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + config_class=GPTNeoXConfig, ) elif sharded: return CausalLM( @@ -643,6 +715,7 @@ def get_model( ) elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: + print(f">>> model_type: {model_type}") if FLASH_ATTENTION: return FlashCausalLM( model_id=model_id, @@ -787,7 +860,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, config_class=RWConfig, ) - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): return FlashCausalLM( @@ -1056,3 +1129,116 @@ def get_model( ) raise ValueError(f"Unsupported model type {model_type}") + + +# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters +# this provides a post model loading hook to load adapters into the model after the model has been loaded +def get_model_with_lora_adapters( + model_id: str, + lora_adapters: Optional[List[AdapterInfo]], + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + speculate: Optional[int], + dtype: Optional[str], + trust_remote_code: bool, + max_input_tokens: int, + adapter_to_index: Dict[str, int], +): + lora_adapter_ids = [adapter.id for adapter in lora_adapters] + model = get_model( + model_id, + lora_adapter_ids, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + max_input_tokens, + ) + + if len(lora_adapters) > 0: + target_to_layer = build_layer_weight_lookup(model.model) + + for index, adapter in enumerate(lora_adapters): + # The AdapterParameters object allows for merging multiple adapters into a single adapter. + # At the moment, we only support loading a single adapter into the model, but we keep the + # AdapterParameters object for easier extension in the future. 
+ adapter_parameters = AdapterParameters( + adapter_info=[adapter], + # when merging multiple adapters we can weight them differently + # if this is not set, all adapters will be weighted equally + # see: text_generation_server.utils.merges.strategies for impl + weights=None, + merge_strategy=0, + density=1.0, + majority_sign_method=0, + ) + + adapter_index = index + 1 + adapter_to_index[adapter.id] = adapter_index + + logger.info( + f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}" + ) + weight_names = tuple([v[0] for v in target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + model.model_id, + adapter_parameters, + adapter_index, + weight_names, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + + adapter_layers = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + + for layer_name in adapter_layers: + nlayers = ( + 1 if layer_name == "lm_head" else len(model.model.model.layers) + ) + adapter_weights = LoraWeights.prepare_weights( + config=adapter_config, + module_map=module_map, + layer_type=layer_name, + unused_weight_names=unused_weight_names, + nlayers=nlayers, + dtype=model.dtype, + world_size=model.world_size, + process_group=model.process_group, + target_to_layer=target_to_layer, + ) + + if adapter_weights is None: + continue + + model.layer_to_adapter_weights[layer_name].add_adapter( + adapter_index, adapter_weights + ) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + model.loaded_adapters.add(adapter_index) + + return model diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 0ea82b1e5..212ab7a90 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -233,7 +233,7 @@ class CausalLMBatch(Batch): ] # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: + if type(self.past_key_values[0]) is tuple: self.past_key_values = [list(layer) for layer in self.past_key_values] # Update tensors in-place to allow incremental garbage collection @@ -377,7 +377,7 @@ class CausalLMBatch(Batch): # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] # And ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: + if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values @@ -492,7 +492,7 @@ class CausalLMBatch(Batch): @dataclass -class CausalLMBatchKeysLast(Batch): +class CausalLMBatchKeysLast(CausalLMBatch): keys_head_dim_last: bool = False @@ -544,7 +544,12 @@ class CausalLM(Model): config.quantize = quantize config.speculator = speculator if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = config.pad_token_id + if config.pad_token_id is not None: + tokenizer.pad_token_id = config.pad_token_id + elif config.eos_token_id is not None: + tokenizer.pad_token_id = config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id 
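The pad-token fallback added to `CausalLM` above follows a fixed priority: the config's pad token, then the config's EOS token, then the tokenizer's own EOS token, and otherwise the pad token stays unset. A minimal sketch of the same chain as a free function; `resolve_pad_token` is not a real helper in the codebase:

# Sketch only: mirrors the fallback chain added above.
def resolve_pad_token(tokenizer, config) -> None:
    if tokenizer.pad_token_id is not None:
        return
    if config.pad_token_id is not None:
        tokenizer.pad_token_id = config.pad_token_id
    elif config.eos_token_id is not None:
        tokenizer.pad_token_id = config.eos_token_id
    elif tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id


# Example with stand-in objects:
from types import SimpleNamespace

tok = SimpleNamespace(pad_token_id=None, eos_token_id=2)
cfg = SimpleNamespace(pad_token_id=None, eos_token_id=None)
resolve_pad_token(tok, cfg)
assert tok.pad_token_id == 2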
torch.distributed.barrier(group=self.process_group) weights_loader = get_loader( diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 77b89c5bf..e2719fad2 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -908,7 +908,7 @@ class BloomForCausalLM(BloomPreTrainedModel): loss = None if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] + output = (logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return ( diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index 27b9ff1cc..ab824da5b 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch from torch import nn @@ -9,9 +9,7 @@ from transformers.modeling_attn_mask_utils import ( _prepare_4d_attention_mask, ) from transformers.modeling_outputs import ( - BaseModelOutput, BaseModelOutputWithPooling, - ImageClassifierOutput, ) from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig @@ -446,11 +444,12 @@ class CLIPEncoder(nn.Module): class CLIPTextTransformer(nn.Module): - def __init__(self, prefix: str, config: CLIPTextConfig): + def __init__(self, prefix: str, config: CLIPTextConfig, weights=None): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = CLIPTextEmbeddings(config) + # Initialize weights and apply final processing with `self.post_init()` self.encoder = CLIPEncoder( prefix=f"{prefix}.encoder", config=config, weights=weights ) @@ -505,7 +504,7 @@ class CLIPTextTransformer(nn.Module): # text_embeds.shape = [batch_size, sequence_length, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 - pooled_output = last_hidden_state[ + last_hidden_state[ torch.arange( last_hidden_state.shape[0], device=last_hidden_state.device ), @@ -515,7 +514,7 @@ class CLIPTextTransformer(nn.Module): ] else: # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) - pooled_output = last_hidden_state[ + last_hidden_state[ torch.arange( last_hidden_state.shape[0], device=last_hidden_state.device ), @@ -565,9 +564,6 @@ class CLIPTextModel(CLIPPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) return self.text_model( input_ids=input_ids, @@ -580,7 +576,6 @@ class CLIPVisionTransformer(nn.Module): def __init__(self, prefix, config: CLIPVisionConfig, weights): super().__init__() self.config = config - embed_dim = config.hidden_size self.embeddings = CLIPVisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights @@ -661,9 +656,6 @@ class CLIPVisionModel(CLIPPreTrainedModel): >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" - return_dict = ( - return_dict if return_dict is not None else 
self.config.use_return_dict - ) return self.vision_model( pixel_values=pixel_values, @@ -799,14 +791,12 @@ class CLIPModel(nn.Module): # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. vision_outputs = self.vision_model( pixel_values=pixel_values, - return_dict=return_dict, ) text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - return_dict=return_dict, ) image_embeds = vision_outputs[1] diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 25719b999..e02a31d9a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,7 +30,6 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -45,6 +44,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight if SYSTEM == "cuda": import dropout_layer_norm @@ -83,6 +83,12 @@ class CohereRotary(PositionRotaryEmbedding): # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, False) + elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), False + ) else: raise ValueError( "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
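For readers without the fused kernels, the rotation that `CohereRotary.forward` dispatches to above (flash-attn's `rotary_emb`, vLLM's `ops.rotary_embedding`, or the new ipex branch, all with the trailing `False` flag, i.e. the non-neox interleaved layout) can be written in plain PyTorch. This is a readability reference under assumed shapes, not a drop-in replacement for the kernels:

import torch


def apply_rotary_interleaved(
    query: torch.Tensor,  # [num_tokens, num_heads, head_size]
    key: torch.Tensor,    # [num_tokens, num_kv_heads, head_size]
    cos: torch.Tensor,    # broadcastable to [num_tokens, 1, head_size // 2]
    sin: torch.Tensor,
) -> None:
    # Rotate adjacent pairs (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos),
    # writing the result back in place as the fused kernels do.
    for x in (query, key):
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        rotated_even = x1 * cos - x2 * sin
        rotated_odd = x1 * sin + x2 * cos
        x[..., 0::2] = rotated_even
        x[..., 1::2] = rotated_odd


# Example with assumed shapes (4 tokens, 8 query heads, 2 KV heads, head size 64):
q = torch.randn(4, 8, 64)
k = torch.randn(4, 2, 64)
angles = torch.rand(4, 1, 32)
apply_rotary_interleaved(q, k, angles.cos(), angles.sin())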
@@ -99,7 +105,7 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm": + if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda": hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) @@ -166,16 +172,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" if config.attention_bias: w = [ @@ -186,9 +192,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class FlashCohereAttention(torch.nn.Module): @@ -259,8 +263,8 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, - input_lengths, slots, + input_lengths, max_s, ): qkv = self.query_key_value(hidden_states) @@ -287,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -305,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 444116870..d3d1d1efc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import ( from text_generation_server.layers.layernorm import ( FastLayerNorm, ) -from text_generation_server.utils.log import log_once class DbrxAttentionConfig(PretrainedConfig): @@ -247,10 +246,10 @@ def _load_experts_quantized(config, prefix, weights, cls): if cls == TensorParallelRowLinear: expert_slice = expert_slice.t().contiguous() - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear, weights.process_group)) else: - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear)) return experts @@ -331,17 +330,13 @@ class DbrxAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if 
cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -349,7 +344,6 @@ class DbrxAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py new file mode 100644 index 000000000..0905d3c29 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -0,0 +1,981 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, +) +from text_generation_server.layers.attention.common import Seqlen +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import Weights +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +if SYSTEM == "rocm": + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") + + +class DeepseekV2Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def _load_experts(config, prefix: str, mat: str, weights: Weights): + if config.quantize is not None: + raise NotImplementedError( + "Deepseek V2 does not support weight quantization yet." 
+ ) + + assert mat in ["gate_proj", "up_proj", "down_proj"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.moe_intermediate_size % world_size == 0 + ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.moe_intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.n_routed_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) + + for i in range(config.n_routed_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "down_proj": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +class DeepseekV2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + ) + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + 
self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + key, + value, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV2MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." 
+ ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + return self.down_proj(out, reduce=reduce) + else: + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class BlockSparseMoE(nn.Module): + def __init__(self, prefix, config: DeepseekV2Config, weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + gate_proj = _load_experts( + config, f"{prefix}.experts", "gate_proj", weights + ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + + up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view( + self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim + ) + + self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1) + + self.down_proj = ( + _load_experts(config, f"{prefix}.experts", "down_proj", weights) + .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + out = ( + fused_experts( + x, + self.gate_up_proj, + self.down_proj, + topk_weights, + topk_ids, + inplace=True, + ) + * self.routed_scaling_factor + ) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DenseMoE(nn.Module): + def __init__(self, prefix: 
str, config: DeepseekV2Config, weights: Weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + # + # Seems like no one quantizes the gate. + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.experts = [ + DeepseekV2MLP( + f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size + ) + for i in range(self.n_routed_experts) + ] + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + # gate_logits: (sequence_length, n_experts) + router_logits = self.gate(x) + + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + + out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + def moe_infer_gpu( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ): + weights = torch.zeros( + topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device + ) + weights.scatter_(1, topk_ids, topk_weight) + + out = x.new_zeros(x.shape[0], self.hidden_dim) + for i, expert in enumerate(self.experts): + # Add expert output to out with masking + out += expert(x, reduce=False) * weights[:, i].view(-1, 1) + return out + + +class DeepseekV2Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE + self.mlp = moe_cls(f"{prefix}.mlp", config, weights) + else: + self.mlp = DeepseekV2MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + 
block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV2Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits + + +# Functions below are from vLLM: +# +# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397 +# +# Remove after we have synced our version with upstream. 
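# grouped_topk implements DeepSeek-V2-style grouped routing: experts are split
# into `num_expert_group` groups, only the `topk_group` groups with the highest
# per-group maximum softmax score survive, and the global top-k is then taken
# over the surviving experts only. A minimal usage sketch (the expert counts and
# group sizes below are illustrative values, not taken from this patch):
#
#   router_logits = gate(hidden_states)              # [num_tokens, n_routed_experts]
#   topk_weights, topk_ids = grouped_topk(
#       hidden_states,
#       router_logits,
#       topk=6,
#       renormalize=True,
#       num_expert_group=8,
#       topk_group=3,
#   )
#   # topk_weights: [num_tokens, 6] mixing coefficients for the chosen experts
#   # topk_ids:     [num_tokens, 6] indices of the chosen experts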
+ + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + # Check constraints. 
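# High-level flow of fused_experts (mirrors vLLM's fused MoE path):
#   1. invoke_fused_moe_kernel with w1 (the concatenated gate/up projections)
#      -> intermediate_cache1 of shape [M, top_k, N]
#   2. silu_and_mul over the two halves of the last dimension
#      -> intermediate_cache2 of shape [M * top_k, N // 2]
#   3. invoke_fused_moe_kernel with w2 (the down projection), applying the
#      routing weights -> intermediate_cache3 of shape [M, top_k, w2.shape[1]]
#   4. sum over the top-k dimension, written directly into hidden_states when
#      inplace=True.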
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + + import triton.language as tl + from vllm import _custom_ops as ops + from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_moe_configs, + invoke_fused_moe_kernel, + moe_align_block_size, + ) + + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config( + M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None + ) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + if inplace: + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a3ce55213..de86f5149 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class Gemma2Config(PretrainedConfig): @@ -144,20 +145,18 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = 
weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemma2Attention(torch.nn.Module): @@ -190,6 +189,7 @@ class FlashGemma2Attention(torch.nn.Module): self.num_key_value_heads = ( config.num_key_value_heads // weights.process_group.size() ) + self.softcap = config.attn_logit_softcapping self.query_key_value = load_attention(config, prefix, weights) @@ -231,27 +231,23 @@ class FlashGemma2Attention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, causal=self.causal, window_size_left=self.window_size, + softcap=self.softcap, ) # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], @@ -260,6 +256,7 @@ class FlashGemma2Attention(torch.nn.Module): block_tables, input_lengths, max_s, + softcap=self.softcap, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -467,6 +464,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module): config=config, weights=weights, ) + self.softcap = config.final_logit_softcapping + assert isinstance(self.softcap, float) def forward( self, @@ -496,4 +495,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) + + logits /= self.softcap + logits = torch.tanh(logits) + logits *= self.softcap + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 34a7efa25..178efadbe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class GemmaConfig(PretrainedConfig): @@ -144,20 +145,18 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // 
weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemmaAttention(torch.nn.Module): @@ -226,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -245,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index cbfcb1b8f..a19cff8cc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights): bias = torch.cat(tensors, dim=0) bias = bias.to(device=weights.device) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def _load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads): 3 * num_heads * head_size ], f"{weight.shape} != {[3 * num_heads * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): @@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool): bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) @@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool): else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) class FlashGPT2Attention(torch.nn.Module): @@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 78832341c..9ea19a87d 
100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager from typing import List, Optional, Tuple import torch @@ -25,7 +26,6 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( @@ -33,7 +33,6 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -46,6 +45,10 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import ( + Weights, +) +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader if SYSTEM == "rocm": try: @@ -105,6 +108,19 @@ def load_attention(config, prefix: str, weights, layer_id): ) +@contextmanager +def no_fp8(weights: Weights): + """De-activate fp8 auto conversion for the duration of this context manager""" + weights_loader = weights.weights_loader + if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8: + weights_loader = HybridFP8UnquantLoader( + weights_loader.activation_scale_ub, to_fp8=False + ) + + with weights.use_loader(weights_loader): + yield + + class FlashLlamaAttention(torch.nn.Module): def __init__( self, @@ -197,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -215,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], @@ -260,7 +271,7 @@ class LlamaMLP(nn.Module): bias=bias, ) else: - prefixes = [f"gate_proj", f"up_proj"] + prefixes = ["gate_proj", "up_proj"] sizes = [ config.intermediate_size, config.intermediate_size, @@ -330,12 +341,15 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): def __init__(self, index, prefix, config, weights): super().__init__() - self.self_attn = FlashLlamaAttention( - index=index, - prefix=f"{prefix}.self_attn", - config=config, - weights=weights, - ) + + with no_fp8(weights): + self.self_attn = FlashLlamaAttention( + index=index, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) @@ -396,7 +410,22 @@ class FlashLlamaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.layers = nn.ModuleList( + + # Skip fp8 quant for first and last layers + self.layers = nn.ModuleList() + with no_fp8(weights): + self.layers.append( + FlashLlamaLayer( + index=0, + prefix=( + "model.layers.0" if not prefix else 
f"{prefix}.model.layers.0" + ), + config=config, + weights=weights, + ) + ) + + self.layers.extend( [ FlashLlamaLayer( index=layer_id, @@ -408,9 +437,26 @@ class FlashLlamaModel(torch.nn.Module): config=config, weights=weights, ) - for layer_id in range(config.num_hidden_layers) + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1) ] ) + + with no_fp8(weights): + last_layer_id = config.num_hidden_layers - 1 + self.layers.append( + FlashLlamaLayer( + index=last_layer_id, + prefix=( + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}.model.layers.{last_layer_id}" + ), + config=config, + weights=weights, + ) + ) + self.norm = FastRMSNorm.load( prefix="model.norm" if not prefix else f"{prefix}.model.norm", weights=weights, @@ -470,23 +516,27 @@ class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() - self.embed_tokens = TensorParallelEmbedding( - prefix=( - "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" - ), - weights=weights, - ) + with no_fp8(weights): + self.embed_tokens = TensorParallelEmbedding( + prefix=( + "model.embed_tokens" + if not prefix + else f"{prefix}.model.embed_tokens" + ), + weights=weights, + ) self.model = FlashLlamaModel(prefix, config, weights) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: suffix = "lm_head" - self.lm_head = SpeculativeHead.load( - config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, - ) + with no_fp8(weights): + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) def forward( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 8028dbe80..dda53ff3d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -28,7 +28,6 @@ from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( - Seqlen, paged_attention, attention, reshape_and_cache, @@ -38,7 +37,6 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - get_linear, TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) @@ -117,7 +115,10 @@ class MistralAttention(torch.nn.Module): ) self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.num_heads + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = self.hidden_size // self.num_heads self.rotary_emb = PositionRotaryEmbedding.static( config=config, @@ -146,15 +147,14 @@ class MistralAttention(torch.nn.Module): bias=False, ) - head_size = config.hidden_size // config.num_attention_heads self.query_key_value = TensorParallelMultiAdapterLinear.load( query_key_value, layer_id, ["q_proj", "k_proj", "v_proj"], sizes=[ - head_size * config.num_attention_heads, - head_size * config.num_key_value_heads, - head_size * config.num_key_value_heads, + self.head_size * config.num_attention_heads, + self.head_size * config.num_key_value_heads, + self.head_size * config.num_key_value_heads, ], process_group=weights.process_group, ) @@ -212,17 +212,13 @@ class 
MistralAttention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 49c0e9030..85431c6c9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -21,7 +21,6 @@ import torch import torch.distributed -import numpy as np from torch import nn from text_generation_server.utils.import_utils import SYSTEM @@ -31,7 +30,6 @@ if SYSTEM != "ipex": from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from loguru import logger from text_generation_server.layers.attention import ( paged_attention, @@ -52,6 +50,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight class MixtralConfig(PretrainedConfig): @@ -138,20 +137,18 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) def _load_experts(config, prefix: str, mat, weights): @@ -272,17 +269,13 @@ class MixtralAttention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -291,7 +284,6 @@ class MixtralAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 85dcb2a6c..b1b03ad75 
100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -24,7 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel -from transformers.models.gpt_neox import GPTNeoXConfig +from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( @@ -45,6 +45,13 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight + + +class GPTNeoXConfig(TransformersGPTNeoXConfig): + attribute_map = { + "num_key_value_heads": "num_attention_heads", + } def load_row(config, prefix: str, weights, bias: bool): @@ -56,7 +63,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: @@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): weight = weights.get_multi_weights_col([prefix], dim=0) - if isinstance(weight, torch.Tensor): + if isinstance(weight, UnquantizedWeight): # Only on non quantized versions - weight = ( - weight.view( + weight.weight = ( + weight.weight.view( num_heads, 3, head_size, @@ -81,7 +88,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: @@ -151,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module): reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(qkv[:, 0]) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -169,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, qkv[:, 0], kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 1f998e5a8..d10efb411 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -16,7 +16,6 @@ import torch import torch.distributed from torch import nn -from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 6c5082642..a9e183480 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -100,9 +100,7 @@ def _load_gqa(config, prefix: str, weights): ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" # this is the same as llama except for Phi uses bias=True - return TensorParallelColumnLinear( - get_linear(weight, bias=True, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=True)) class FlashPhiAttention(torch.nn.Module): @@ -190,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module): # Reshape key and value and cache reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -207,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index a98709c51..865cc85de 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -15,7 +15,6 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, - get_linear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -131,17 +130,13 @@ class Qwen2Attention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -150,7 +145,6 @@ class Qwen2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], @@ -262,6 +256,9 @@ class Qwen2Layer(nn.Module): class Qwen2Model(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() + + prefix = f"{prefix}.model" if prefix else "model" + process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() @@ -335,15 +332,16 @@ class Qwen2ForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() - if not prefix: - prefix = "model" - else: - prefix = f"{prefix}.model" - self.model = Qwen2Model(prefix, config, weights) + + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + self.lm_head = SpeculativeHead.load( config, - prefix="lm_head", + prefix=f"{prefix}.{suffix}" if prefix else suffix, weights=weights, ) self.max_past = config.sliding_window diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 65b40fed6..708641e7f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.parallel_attn: return linear else: @@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], @@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module): slots, ) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 77b9d49cd..c26767822 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -105,6 +105,7 @@ def _load_multi_mqa_gptq( g_idx=g_idx, bits=loader.bits, groupsize=loader.groupsize, + use_awq_kernel=loader.quantize == "awq", use_exllama=HAS_EXLLAMA, ) @@ -121,7 +122,7 @@ def _load_multi_mqa_gptq( bias = torch.cat([q_tensor, kv_tensor], dim=0) bias = bias.to(device=weights.device) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) else: raise NotImplementedError("Gptq loading with santacoder is not implemented") @@ -193,7 +194,7 @@ def _load_multi_mqa( assert list(bias.shape) == [ (num_heads + 2) * head_size ], f"{weight.shape} != {[(num_heads + 2) * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_col(config, prefix: str, weights, bias: bool): @@ -206,7 +207,7 @@ def load_col(config, prefix: str, weights, bias: bool): bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): @@ -221,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) @@ -285,17 +286,13 @@ class FlashMQAttention(torch.nn.Module): key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots ) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( 
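# Note on the attention helpers used here and in the other modeling files:
# `attention(...)` (prefill / flash path) and `paged_attention(...)` (decode
# path) both return the attention output tensor, so no preallocated
# `attn_output = torch.empty_like(query)` buffer is needed at the call sites.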
+ attn_output = attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -303,7 +300,6 @@ class FlashMQAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 19556f783..e562eb896 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight class Starcoder2Config(PretrainedConfig): @@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" if config.use_bias: w = [ @@ -149,9 +150,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class Starcoder2Attention(torch.nn.Module): @@ -236,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -255,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index a83bc1c64..7e4deaf88 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -14,7 +14,7 @@ # limitations under the License. 
""" PyTorch Idefics2 model.""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.utils.checkpoint @@ -22,10 +22,8 @@ from torch import nn import math from transformers.activations import ACT2FN -from transformers.image_processing_utils import select_best_resolution from text_generation_server.models.custom_modeling.vlm import ( load_text_model, - load_vision_model, ) from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask @@ -34,6 +32,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelRowLinear, ) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -682,7 +681,7 @@ class Idefics2Connector(nn.Module): class Idefics2ForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - config.vision_config.quantize = config.quantize + config.vision_config.quantize = None config.vision_config.speculator = config.speculator config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator @@ -695,16 +694,24 @@ class Idefics2ForConditionalGeneration(nn.Module): name="text_model", ) self.dtype = weights.dtype - self.vision_model = Idefics2VisionTransformer( - prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model", - config=vision_config, - weights=weights, - ) - self.connector = Idefics2Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) + + # The vision and connector models are not quantized. + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + self.config = config self.image_seq_len = config.perceiver_config.resampler_n_latents self.image_token_id = config.image_token_id @@ -823,6 +830,7 @@ class Idefics2ForConditionalGeneration(nn.Module): max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, + adapter_data=adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py index e323d365c..afb8e1f9c 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py @@ -19,6 +19,7 @@ import numpy as np from PIL import Image +import transformers from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.image_transforms import ( resize, @@ -293,6 +294,4 @@ class IdeficsImageProcessor(BaseImageProcessor): return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs) -import transformers - transformers.IdeficsImageProcessor = IdeficsImageProcessor diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 786ef559b..fc6becc4b 100644 --- 
a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -21,10 +21,8 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from transformers import PreTrainedModel from transformers.activations import ACT2FN @@ -33,13 +31,6 @@ from transformers.modeling_outputs import ( CausalLMOutputWithPast, dataclass, ) -from transformers.modeling_utils import PretrainedConfig -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig from text_generation_server.models.custom_modeling.idefics_vision import ( IdeficsVisionTransformer, @@ -56,6 +47,7 @@ from text_generation_server.layers import ( ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.utils.import_utils import SYSTEM +from loguru import logger if SYSTEM == "cuda": import dropout_layer_norm @@ -237,7 +229,7 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module): prefix="model.embed_tokens", weights=weights ) self.additional_weight = nn.Parameter( - weights.get_tensor(f"model.embed_tokens.additional_embedding.weight") + weights.get_tensor("model.embed_tokens.additional_embedding.weight") ) def forward(self, input_ids): @@ -359,7 +351,19 @@ class IdeficsRMSNorm(nn.Module): self.variance_epsilon = eps def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: + if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + out = ipex.llm.functional.add_rms_norm( + residual, + hidden_states, + self.weight, + None, + self.variance_epsilon, + residual is not None, + ) + return out + elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual residual = hidden_states @@ -499,7 +503,6 @@ class IdeficsAttention(nn.Module): # if not hasattr(nn.functional, "scaled_dot_product_attention"): # raise ValueError("this model requires pytorch 2.0 or higher") - process_group = weights.process_group if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -1024,7 +1027,7 @@ class IdeficsModel(IdeficsPreTrainedModel): if config.use_resampler: perceiver_config = config.perceiver_config self.perceiver_resampler = IdeficsPerceiverResampler( - prefix=f"model.perceiver_resampler", + prefix="model.perceiver_resampler", config=config, embed_dim=config.vision_config.embed_dim, depth=perceiver_config.resampler_depth, @@ -1052,7 +1055,7 @@ class IdeficsModel(IdeficsPreTrainedModel): # self.gradient_checkpointing = False self.norm = IdeficsRMSNorm( - prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps + prefix="model.norm", weights=weights, eps=config.rms_norm_eps ) # self.gradient_checkpointing = False diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py index af44490b3..6da8045bc 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py +++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py @@ -169,7 +169,6 @@ class IdeficsPerceiverAttention(nn.Module): self.qk_scale = self.head_dim**-0.5 - 
process_group = weights.process_group if n_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} " diff --git a/server/text_generation_server/models/custom_modeling/idefics_processing.py b/server/text_generation_server/models/custom_modeling/idefics_processing.py index 7bba6977e..ca61e27d4 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_processing.py @@ -28,9 +28,6 @@ from transformers.tokenization_utils_base import ( TruncationStrategy, ) from transformers.utils import TensorType, is_torch_available -from text_generation_server.models.custom_modeling.idefics_image_processing import ( - IdeficsImageProcessor, -) if is_torch_available(): diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py index 30c5997fe..dd8f76bc4 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_vision.py +++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py @@ -129,7 +129,6 @@ class IdeficsVisionAttention(nn.Module): self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - process_group = weights.process_group if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -460,7 +459,6 @@ class IdeficsVisionTransformer(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config - embed_dim = config.hidden_size self.embeddings = IdeficsVisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 567131ef7..29f5b9c71 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -14,7 +14,7 @@ # limitations under the License. 
""" PyTorch Llava-NeXT model.""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.utils.checkpoint @@ -280,6 +280,7 @@ class LlavaNextForConditionalGeneration(nn.Module): max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, + adapter_data=adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index fb09a8f17..988a74a39 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -4,7 +4,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ import math -import os import warnings from typing import List, Optional, Tuple, Union import torch @@ -75,7 +74,7 @@ def load_col(config, prefix, weights, bias): bias = bias.to(device=weights.device) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return TensorParallelColumnLinear(linear) @@ -194,7 +193,7 @@ def flash_attn_fn( ): try: from flash_attn import bert_padding, flash_attn_interface - except: + except Exception: raise RuntimeError("Please install flash-attn==1.0.3.post0") check_valid_inputs(query, key, value) if past_key_value is not None: @@ -207,7 +206,7 @@ def flash_attn_fn( _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: - raise NotImplementedError(f"attn_bias not implemented for flash attn.") + raise NotImplementedError("attn_bias not implemented for flash attn.") (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) @@ -269,13 +268,13 @@ def triton_flash_attn_fn( ): try: from .flash_attn_triton import flash_attn_func - except: + except Exception: _installed = False if version.parse(torch.__version__) < version.parse("2.0.0"): _installed = True try: from flash_attn.flash_attn_triton import flash_attn_func - except: + except Exception: _installed = False if not _installed: raise RuntimeError( @@ -292,9 +291,9 @@ def triton_flash_attn_fn( _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if dropout_p: - raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") + raise NotImplementedError("Dropout not implemented for attn_impl: triton.") if needs_weights: - raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") + raise NotImplementedError("attn_impl: triton cannot return attn weights.") if key_padding_mask is not None: warnings.warn( "Propagating key_padding_mask to the attention module " @@ -337,17 +336,17 @@ class MultiheadAttention(nn.Module): weights, ): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: 
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop if self.n_heads % weights.process_group.size() != 0: raise ValueError( @@ -428,24 +427,24 @@ class MultiQueryAttention(nn.Module): additive bias. """ - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix, weights, verbose=False): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) self.Wqkv = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias ) - fuse_splits = (d_model, d_model + self.head_dim) + (d_model, d_model + self.head_dim) if self.qk_ln: raise NotImplementedError("qk_ln not supported") if self.attn_impl == "flash": @@ -614,9 +613,9 @@ class MPTBlock(nn.Module): def __init__(self, config, prefix, weights): super().__init__() self.prefix = prefix - if config.attn_config["attn_type"] != "multihead_attention": + if config.attn_config.attn_type != "multihead_attention": raise NotImplementedError( - f"""Not implemented attn {config.attn_config["attn_type"]}""" + f"""Not implemented attn {config.attn_config.attn_type}""" ) resid_pdrop = config.resid_pdrop if config.no_bias: @@ -789,13 +788,15 @@ class MPTModel(MPTPreTrainedModel): self.world_size = weights.process_group.size() self.rank = weights.process_group.rank() self.n_heads = config.n_heads - self.attn_impl = config.attn_config["attn_impl"] - self.prefix_lm = config.attn_config["prefix_lm"] - self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] - self.alibi = config.attn_config["alibi"] - self.alibi_bias_max = config.attn_config["alibi_bias_max"] + self.attn_impl = config.attn_config.attn_impl + self.prefix_lm = config.attn_config.prefix_lm + self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id + self.alibi = config.attn_config.alibi + self.alibi_bias_max = config.attn_config.alibi_bias_max if config.init_device == "mixed": - if dist.get_local_rank() == 0: + # TODO: reimplement mixed device initialization + # dist.get_local_rank() == 0: + if True: config.init_device = "cpu" else: config.init_device = "meta" @@ -1016,7 +1017,7 @@ class MPTModel(MPTPreTrainedModel): if past_key_values is not None: if len(past_key_values) != self.config.n_layers: raise ValueError( - f"past_key_values must provide a past_key_value for each attention " + "past_key_values must provide a past_key_value for each attention " + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." 
) past_position = past_key_values[0][0].size(1) @@ -1182,7 +1183,7 @@ class MPTForCausalLM(MPTPreTrainedModel): input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) - if kwargs.get("use_cache") == False: + if kwargs.get("use_cache") is False: raise NotImplementedError( "MPT with prefix_lm=True does not support use_cache=False." ) diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 8998778fd..06731a6f9 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -21,25 +21,14 @@ import torch import torch.distributed import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN -from transformers.file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel -from transformers import GPTNeoXConfig -from loguru import logger from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -133,7 +122,6 @@ class GPTNeoXAttention(nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_attention_heads self.rotary_ndims = int(self.head_size * config.rotary_pct) - max_positions = config.max_position_embeddings # ??? 
TODO # self.register_buffer( # "bias", diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index b4d56db13..3f2ed010f 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -5,7 +5,7 @@ import torch.distributed import math from torch import nn -from typing import Optional, List, Tuple, Any +from typing import Optional, List, Tuple from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py index 5fbc6d293..480d0f9f3 100644 --- a/server/text_generation_server/models/custom_modeling/siglip.py +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -1,20 +1,15 @@ -from typing import Optional, Tuple, Union - +from typing import Optional, Tuple +import warnings import math import torch from torch import nn from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import ( - _create_4d_causal_attention_mask, - _prepare_4d_attention_mask, -) from transformers.modeling_outputs import ( - BaseModelOutput, BaseModelOutputWithPooling, - ImageClassifierOutput, ) -from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig +from transformers import SiglipConfig, SiglipVisionConfig +from torch.nn.init import _calculate_fan_in_and_fan_out from text_generation_server.layers.tensor_parallel import ( TensorParallelEmbedding, @@ -244,9 +239,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): return hidden_state[:, 0] -import warnings - - def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -264,12 +256,12 @@ def _trunc_normal_(tensor, mean, std, a, b): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.uniform_(2 * lower - 1, 2 * upper - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal @@ -313,9 +305,6 @@ def trunc_normal_tf_( tensor.mul_(std).add_(mean) -from torch.nn.init import _calculate_fan_in_and_fan_out - - def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": @@ -349,9 +338,6 @@ def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") -from transformers import PreTrainedModel - - class SiglipEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. 
Each layer is a @@ -393,7 +379,6 @@ class SiglipVisionTransformer(nn.Module): def __init__(self, prefix, config: SiglipVisionConfig, weights): super().__init__() self.config = config - embed_dim = config.hidden_size self.embeddings = SiglipVisionEmbeddings( prefix=f"{prefix}.embeddings", config=config, weights=weights diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 0b899fba1..e6666acd3 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -45,6 +45,15 @@ from text_generation_server.layers import ( SpeculativeHead, ) +# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316 +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + class PartialTPEmbedding(nn.Module): def __init__(self, prefix: str, weights): @@ -1132,12 +1141,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel): if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) # move labels to correct device to enable PP - labels = labels.to(lm_logits.device) - loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + output = (logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output return ( diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index b74b43ff9..e5c44045a 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -42,7 +42,7 @@ def load_vision_model(prefix, config, weights): ) return SiglipVisionTransformer( - prefix=f"vision_tower.vision_model", config=config, weights=weights + prefix="vision_tower.vision_model", config=config, weights=weights ) else: raise RuntimeError(f"Unsupported model type {config.model_type}") diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2ca9eef33..36bb26621 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,7 +1,6 @@ import math import os import time -import itertools import torch import torch.distributed @@ -23,14 +22,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model +from text_generation_server.utils.log import log_master from 
text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - hub, ) from text_generation_server.models.types import ( Batch, @@ -45,7 +43,6 @@ from text_generation_server.models.globals import ( BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, - MODEL_ID, ) from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser @@ -839,7 +836,9 @@ class FlashCausalLM(Model): default_dtype=torch.float16, aliases=None, # Used for Santacoder override of config - num_kv_heads=None, + num_kv_heads: Optional[int] = None, + # Deepseek V2 uses different QK and V dims. + head_size: Optional[int] = None, skip_special_tokens: bool = True, ): self.process_group, rank, world_size = initialize_torch_distributed() @@ -922,7 +921,16 @@ class FlashCausalLM(Model): else num_kv_heads ) assert self.num_kv_heads > 0 - self.head_size = config.hidden_size // config.num_attention_heads + + if head_size is None: + # Some models use GQA and different sizes for o_proj + # and q_proj, that allows for that. + if hasattr(config, "head_dim"): + self.head_size = config.head_dim + else: + self.head_size = config.hidden_size // config.num_attention_heads + else: + self.head_size = head_size self.cuda_graphs = {} self.kv_cache = [] @@ -1147,42 +1155,49 @@ class FlashCausalLM(Model): tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, - f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", + f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) - logger.info( - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`." + log_master( + logger.info, + f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) if os.path.isfile(tunableop_filepath): - logger.info( - f"The file {tunableop_filepath} already exists and will be reused." 
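To make the new `head_size` handling in the `FlashCausalLM` constructor above concrete, here is a hedged sketch of the resolution order it enables: an explicit `head_size` argument (for models whose QK and V dims differ, such as Deepseek V2) wins, then a config-provided `head_dim`, then the usual `hidden_size // num_attention_heads` fallback. `DummyConfig` and `resolve_head_size` are made-up names for illustration, not real TGI API.

```python
# Sketch of the head-size resolution order; not the actual TGI implementation.
from typing import Optional


def resolve_head_size(config, head_size: Optional[int] = None) -> int:
    # Explicit override first (used when QK and V dimensions differ).
    if head_size is not None:
        return head_size
    # Some GQA models ship a head_dim that differs from hidden_size // heads.
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


class DummyConfig:  # hypothetical config, for illustration only
    hidden_size = 4096
    num_attention_heads = 32


print(resolve_head_size(DummyConfig()))                 # 128
print(resolve_head_size(DummyConfig(), head_size=192))  # 192
```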
+ log_master( + logger.info, + f"The file {tunableop_filepath} already exists and will be reused.", ) torch.cuda.tunable.read_file(tunableop_filepath) os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) for seqlen in tuning_sequences: - logger.info(f"Warming up TunableOp for seqlen={seqlen}") + log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) torch.cuda.tunable.tuning_enable(False) else: - logger.info( - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp." + log_master( + logger.info, + "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", ) if CUDA_GRAPHS: try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") + log_master( + logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" + ) # Warmup cuda graphs for bs in CUDA_GRAPHS: if self.speculate is None or self.speculate + 1 <= bs: self.cuda_graph_warmup(bs, max_s, max_bt) except torch.cuda.OutOfMemoryError: - logger.exception(f"Decode cuda graph warmup failed") + logger.exception("Decode cuda graph warmup failed") else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") + log_master( + logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." + ) return int(num_blocks * BLOCK_SIZE) @@ -1534,8 +1549,7 @@ class FlashCausalLM(Model): left = 0 if n_accepted_ids > 1: - if RANK == 0: - logger.debug(f"Speculated ids {n_accepted_ids - 1}") + log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}") current_stopped = False for j in range(index, index + n_accepted_ids): @@ -1684,72 +1698,3 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. 
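The logging churn in this file (`logger.info(...)` becoming `log_master(logger.info, ...)`) exists so that only rank 0 prints in multi-shard runs. A small dependency-free sketch of the pattern follows; the real helper lives in `text_generation_server/utils/log.py` and appears later in this diff, while `RANK` is hard-coded here and standard `logging` is used instead of loguru just to keep the example self-contained.

```python
# Rank-gated logging sketch; RANK is hard-coded for illustration only.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tgi-sketch")

RANK = 0  # in TGI this is derived from the torch.distributed environment


def log_master(log, msg: str):
    # Only the master shard emits the message, so a tensor-parallel run does
    # not print the same warmup line once per GPU.
    if RANK == 0:
        log(msg)


log_master(logger.info, "Warming up TunableOp for seqlen=2")
```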
- if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py deleted file mode 100644 index 2b2bd2e04..000000000 --- a/server/text_generation_server/models/flash_mistral.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM - - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashMistral(FlashCausalLM): - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. 
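Both the adapter block removed from `FlashCausalLM` above and this deleted `FlashMistral` class carried near-identical `adapter_target_to_layer` maps; later in this diff they are consolidated into a single helper (`build_layer_weight_lookup` in `utils/adapter.py`). The sketch below only mimics the attribute shape that lookup relies on, using toy placeholder classes rather than real modules, and keeps just the attention projections for brevity.

```python
# Toy sketch of the layer-to-weight lookup idea; placeholder classes only.
class _ToyAttn:
    query_key_value = "qkv-weight"
    o_proj = "o-weight"


class _ToyLayer:
    self_attn = _ToyAttn()


class _ToyInner:
    layers = [_ToyLayer(), _ToyLayer()]


class _ToyModel:
    model = _ToyInner()


def adapter_targets(model):
    # VLMs nest their language model under `language_model` or `text_model`;
    # plain causal LMs expose `model` directly.
    if hasattr(model, "language_model"):
        m = model.language_model.model
    elif hasattr(model, "text_model"):
        m = model.text_model.model
    else:
        m = model.model

    targets = {}
    for i, layer in enumerate(m.layers):
        for proj in ("q_proj", "k_proj", "v_proj"):
            targets[(i, proj)] = (
                f"model.layers.{i}.self_attn.{proj}",
                layer.self_attn.query_key_value,
            )
        targets[(i, "o_proj")] = (
            f"model.layers.{i}.self_attn.o_proj",
            layer.self_attn.o_proj,
        )
    return targets


print(adapter_targets(_ToyModel())[(1, "o_proj")])
```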
- if hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - # TODO: this is a hack to avoid the gate_proj for - # FlashStarcoder2 that doesnt have these layers - if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 2d43244a0..7c4e462c7 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -2,23 +2,15 @@ import re import torch import torch.distributed -from typing import List, Optional, Type from transformers import ( - AutoTokenizer, - AutoConfig, PreTrainedTokenizerBase, ) -from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, - initialize_torch_distributed, - weight_files, - Weights, ) from text_generation_server.utils.chunks import concat_text_chunks diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 06035ccdd..8d2431dbc 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,15 +1,16 @@ import torch import os from loguru import logger -from typing import Dict +from typing import Dict, Optional + +from text_generation_server.utils.log import log_master MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: - logger.info("Using FLASH_DECODING") - + log_master(logger.info, "Using FLASH_DECODING") cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: @@ -26,23 +27,11 @@ else: if cuda_graphs is not None: cuda_graphs.sort(reverse=True) - CUDA_GRAPHS = cuda_graphs -# This is overridden at model loading. 
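As a quick illustration of the `globals.py` settings touched in the hunk above: `FLASH_DECODING` switches the KV-cache block size from 16 to 256, and `CUDA_GRAPHS` lists the batch sizes to capture. The `FLASH_DECODING`/`BLOCK_SIZE` lines follow the diff; the `CUDA_GRAPHS` parsing is abbreviated here as an assumed comma-separated list of ints, since the hunk elides the actual parsing code.

```python
import os

# Follows the diff: FLASH_DECODING bumps the block size used for the KV cache.
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16

# Assumed parsing (the hunk only shows the sort): comma-separated batch sizes.
raw = os.getenv("CUDA_GRAPHS", "1,2,4,8")
cuda_graphs = [int(x) for x in raw.split(",")]
cuda_graphs.sort(reverse=True)  # largest graphs get warmed up first
CUDA_GRAPHS = cuda_graphs

print(BLOCK_SIZE, CUDA_GRAPHS)
```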
-global MODEL_ID -MODEL_ID = None - - -def set_model_id(model_id: str): - global MODEL_ID - MODEL_ID = model_id - - # NOTE: eventually we should move this into the router and pass back the # index in all cases. -global ADAPTER_TO_INDEX -ADAPTER_TO_INDEX: Dict[str, int] = None +ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None def set_adapter_to_index(adapter_to_index: Dict[str, int]): diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 0deab6ce2..29929b982 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -1,13 +1,8 @@ import torch import torch.distributed -from typing import List, Optional, Tuple +from typing import Optional -from transformers import ( - AutoTokenizer, - AutoConfig, - AutoProcessor, -) from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig from text_generation_server.models.custom_modeling.idefics_processing import ( @@ -25,6 +20,8 @@ from text_generation_server.utils import ( ) from text_generation_server.utils.quantization import get_loader +from text_generation_server.utils.import_utils import SYSTEM + class IDEFICSSharded(IdeficsCausalLM): def __init__( @@ -42,6 +39,14 @@ class IDEFICSSharded(IdeficsCausalLM): # 9b seems to work correctly enough in float16, but 80b seems # to be really saturating for f16. dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 if dtype is None else dtype diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 6c5629808..8a80ed682 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch): image_hidden_states = self.image_hidden_states[keep_indices] # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: + if type(self.past_key_values[0]) is tuple: self.past_key_values = [list(layer) for layer in self.past_key_values] # Update tensors in-place to allow incremental garbage collection @@ -456,7 +456,7 @@ class IdeficsCausalLMBatch(Batch): # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] # And ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: + if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 4ed9722c4..5d6ce3c76 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -2,7 +2,6 @@ import torch import torch.distributed from transformers import AutoTokenizer, PreTrainedTokenizerBase from typing import Optional -import os from text_generation_server.models.custom_modeling.mamba_modeling import ( MambaConfig, ) @@ -20,7 +19,7 @@ from 
text_generation_server.models.custom_modeling.mamba_modeling import ( InferenceParams, ) from text_generation_server.models import Model -from typing import Any, List, Optional, Tuple, Type, Dict +from typing import Any, List, Tuple, Type, Dict from text_generation_server.models.types import ( Batch, Tokens, @@ -31,7 +30,7 @@ from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.utils import NextTokenChooser, StoppingCriteria def new_inference_params( @@ -299,7 +298,6 @@ class MambaBatch(Batch): stopping_criterias = [] top_n_tokens = [] max_tokens = 0 - max_seqlen = 0 seqlen_offset = 0 (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape @@ -485,7 +483,7 @@ class Mamba(Model): for bs in CUDA_GRAPHS: self.cuda_graph_warmup(bs) except Exception: - logger.exception(f"Decode cuda graph warmup failed") + logger.exception("Decode cuda graph warmup failed") else: logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") @@ -534,7 +532,7 @@ class Mamba(Model): } self.cuda_graphs[batch_size] = graph_dict - def tunableop_warmup(self, seqlen: int): + def tunableop_warmup(self, batch_size: int, seqlen: int): input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) n_blocks = len(self.model.blocks) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 09130b859..20402e07a 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -2,21 +2,14 @@ import inspect import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict +from typing import List, Tuple, Optional, TypeVar, Type, Dict from collections import defaultdict -from transformers import PreTrainedTokenizerBase, PretrainedConfig +from transformers import PreTrainedTokenizerBase from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights -from text_generation_server.utils.adapter import ( - load_and_merge_adapters, - AdapterParameters, - AdapterSource, -) -from loguru import logger - BASE_MODEL_ADAPTER_ID = "__base_model__" @@ -60,7 +53,6 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -141,138 +133,3 @@ class Model(ABC): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) - - @property - def supports_adapter_loading(self) -> bool: - return False - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - return {} - - @property - def adapter_layers(self) -> List[str]: - return [] - - @property - def default_traced_adapter_layers(self) -> List[str]: - return [] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 0 - - def is_row_parallel(self, layer_type: str) -> bool: - return False - - @property - def max_speculative_tokens(self) -> int: - 
return max( - [ - weights.max_speculative_tokens - for weights in self.layer_to_adapter_weights.values() - ], - default=0, - ) - - def load_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - api_token: str, - dynamic: bool = True, - ): - """Loads adapter weights from disk / host memory on the GPU. - - adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded - into model. Otherwise, the adapter weights are applied during the forward - pass and stored separately from the base model parameters. - """ - if self.target_to_layer is None: - self.target_to_layer = self.adapter_target_to_layer() - if adapter_index in self.loaded_adapters: - # Adapter already loaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if dynamic and not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." - ) - - logger.info( - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" - ) - weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - ( - module_map, - adapter_config, - adapter_weight_names, - adapter_tokenizer, - ) = load_and_merge_adapters( - self.model_id, - adapter_parameters, - adapter_source, - adapter_index, - weight_names, - api_token, - False, - ) - - unused_weight_names = adapter_weight_names.copy() - for layer_name in self.adapter_layers: - adapter_weights = adapter_config.load_batched_adapter_weights( - self, - module_map, - layer_name, - unused_weight_names, - dynamic, - ) - - if adapter_weights is None: - continue - - layer_weights = self.layer_to_adapter_weights[layer_name] - layer_weights.add_adapter(adapter_index, adapter_weights) - - if len(unused_weight_names) > 0: - logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" - ) - - if adapter_tokenizer is not None: - self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) - - self.loaded_adapters.add(adapter_index) - - def offload_adapter( - self, - adapter_parameters: AdapterParameters, - adapter_source: AdapterSource, - adapter_index: int, - ): - """Offloads the adapter weights from GPU to CPU or disk.""" - if adapter_index not in self.loaded_adapters: - # Adapter already offloaded - return - - if not self.supports_adapter_loading: - raise ValueError("This model does not support adapter loading.") - - if not self.dynamic_adapter_loading_enabled: - raise ValueError( - f"This model was initialized with the adapter {self.static_adapter_id} " - f"and therefore does not support dynamic adapter loading. " - f"Please initialize a new model instance from the base model in " - f"order to use the dynamic adapter loading feature." 
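The adapter bookkeeping being removed from the `Model` base class here (and re-homed by this refactor) boils down to an index set guarding against double loads and double offloads. A toy sketch of that guard pattern is below; the class and method names are illustrative, not the real TGI API.

```python
# Toy sketch of the loaded-adapter guard pattern visible in the removed code.
class AdapterRegistry:
    def __init__(self):
        self.loaded_adapters = set()

    def load(self, adapter_index: int) -> bool:
        if adapter_index in self.loaded_adapters:
            return False  # already loaded, nothing to do
        self.loaded_adapters.add(adapter_index)
        return True

    def offload(self, adapter_index: int) -> bool:
        if adapter_index not in self.loaded_adapters:
            return False  # already offloaded
        self.loaded_adapters.remove(adapter_index)
        return True


reg = AdapterRegistry()
print(reg.load(1), reg.load(1), reg.offload(1), reg.offload(1))
# True False True False
```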
- ) - - for layer_name in self.adapter_layers: - if layer_name in self.layer_to_adapter_weights: - self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) - - self.loaded_adapters.remove(adapter_index) diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index 3994ac707..fe75570ea 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -3,16 +3,11 @@ from PIL import Image import torch import torch.distributed from opentelemetry import trace -from typing import Iterable, Optional, Tuple +from typing import Iterable from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLM, VlmCausalLMBatch, image_text_replacement, ) -from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( - PaliGemmaForConditionalGeneration, -) -from transformers import AutoProcessor, AutoConfig from text_generation_server.pb.generate_pb2 import Request diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index fa8b50256..79c001b0b 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,7 +1,6 @@ import torch import torch.distributed import time - from dataclasses import dataclass from opentelemetry import trace from transformers import ( @@ -11,7 +10,7 @@ from transformers import ( AutoConfig, ) from typing import Optional, Tuple, List, Type, Dict - +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -254,7 +253,7 @@ class Seq2SeqLMBatch(Batch): ] # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: + if type(self.past_key_values[0]) is tuple: self.past_key_values = [ [t for t in layer] for layer in self.past_key_values ] @@ -430,7 +429,7 @@ class Seq2SeqLMBatch(Batch): batch.encoder_last_hidden_state = None # Ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: + if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ [t for t in layer] for layer in batch.past_key_values ] diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 339b733b5..d4e7cca75 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -1,4 +1,3 @@ -from functools import total_ordering import torch from abc import ABC, abstractmethod diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index ace488058..7de54aa44 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,4 +1,3 @@ -from itertools import repeat import torch from PIL import Image from io import BytesIO @@ -13,7 +12,9 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) +from text_generation_server.utils.log import log_master from transformers import AutoProcessor +from text_generation_server.layers.attention import Seqlen tracer = trace.get_tracer(__name__) @@ -56,8 +57,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info( - f"Found 
{num_features} features in image of resolution {height}x{width}" + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", ) return "" * num_features @@ -261,7 +263,12 @@ class VlmCausalLM(FlashCausalLM): **processor_kwargs, ) self.batch_class = batch_class - super().__init__(model_id=model_id, **kwargs) + super().__init__( + model_id=model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **kwargs, + ) @property def batch_type(self) -> Type[VlmCausalLMBatch]: @@ -342,6 +349,7 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index aee287c67..b92ab572a 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -13,7 +13,8 @@ from typing import List, Optional from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor -from text_generation_server.models import Model, get_model +from text_generation_server.models import Model, get_model_with_lora_adapters +from text_generation_server.utils.adapter import AdapterInfo try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -29,10 +30,7 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id, set_adapter_to_index -from text_generation_server.utils.adapter import ( - AdapterParameters, -) +from text_generation_server.models.globals import set_adapter_to_index class SignalHandler: @@ -195,7 +193,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( model_id: str, - lora_adapter_ids: Optional[List[str]], + lora_adapters: Optional[List[AdapterInfo]], revision: Optional[str], sharded: bool, quantize: Optional[str], @@ -207,7 +205,7 @@ def serve( ): async def serve_inner( model_id: str, - lora_adapter_ids: Optional[List[str]], + lora_adapters: Optional[List[AdapterInfo]], revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, @@ -228,9 +226,9 @@ def serve( server_urls = [local_url] try: - model = get_model( + model = get_model_with_lora_adapters( model_id, - lora_adapter_ids, + lora_adapters, revision, sharded, quantize, @@ -238,29 +236,9 @@ def serve( dtype, trust_remote_code, max_input_tokens, + adapter_to_index, ) - if len(lora_adapter_ids) > 0: - for index, adapter_id in enumerate(lora_adapter_ids): - # TODO: improve non merged adapter loading and long term - # improve adapter loading as a whole - adapter_parameters = AdapterParameters( - adapter_ids=[adapter_id], - weights=None, # will be set to 1 - merge_strategy=0, - density=1.0, - majority_sign_method=0, - ) - adapter_index = index + 1 - adapter_to_index[adapter_id] = adapter_index - model.load_adapter( - adapter_parameters, - None, # adapter_source - adapter_index, - None, # api_token - False, # dynamic - ) - except Exception: logger.exception("Error when initializing model") raise @@ -293,11 +271,10 @@ def serve( while signal_handler.KEEP_PROCESSING: await asyncio.sleep(0.5) - set_model_id(model_id) asyncio.run( serve_inner( model_id, - 
lora_adapter_ids, + lora_adapters, revision, sharded, quantize, diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 4e2492de0..1db5f77b6 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -5,12 +5,11 @@ import warnings from dataclasses import dataclass from functools import lru_cache -from typing import TYPE_CHECKING, Set, Tuple +from typing import TYPE_CHECKING, Set, Tuple, Optional, List from safetensors.torch import load_file from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer -from text_generation_server.pb import generate_pb2 from text_generation_server.utils.merges.strategies import merge_adapters from text_generation_server.utils import hub @@ -24,9 +23,15 @@ if TYPE_CHECKING: BASE_MODEL_ADAPTER_ID = "__base_model__" +@dataclass +class AdapterInfo: + id: str + path: Optional[str] + + @dataclass class AdapterParameters: - adapter_ids: Tuple[str] + adapter_info: Tuple[AdapterInfo] weights: Tuple[float] merge_strategy: NotImplemented density: float @@ -40,37 +45,47 @@ class AdapterSource: revision: str +def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]: + if not lora_adapters: + return [] + + adapter_list = [] + for adapter in lora_adapters.split(","): + parts = adapter.strip().split("=") + if len(parts) == 1: + adapter_list.append(AdapterInfo(id=parts[0], path=None)) + elif len(parts) == 2: + adapter_list.append(AdapterInfo(id=parts[0], path=parts[1])) + else: + raise ValueError(f"Invalid LoRA adapter format: {adapter}") + return adapter_list + + def load_and_merge_adapters( model_id: str, adapter_parameters: AdapterParameters, - adapter_source: str, adapter_index: int, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: - if len(adapter_parameters.adapter_ids) == 1: + + if len(adapter_parameters.adapter_info) == 1: + adapter_info = next(iter(adapter_parameters.adapter_info)) return load_module_map( model_id, - adapter_parameters.adapter_ids[0], - adapter_source, + adapter_info.id, + adapter_info.path, weight_names, - api_token, trust_remote_code, ) - adapter_params = AdapterParametersContainer( - adapter_parameters, adapter_source, adapter_index - ) - return _load_and_merge( - model_id, adapter_params, weight_names, api_token, trust_remote_code - ) + adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) + return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code) @dataclass class AdapterParametersContainer: adapter_parameters: AdapterParameters - adapter_source: str adapter_index: int def __hash__(self) -> int: @@ -82,7 +97,6 @@ def _load_and_merge( model_id: str, adapter_params: AdapterParametersContainer, weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: params = adapter_params.adapter_parameters @@ -90,17 +104,16 @@ def _load_and_merge( adapters_to_merge = [] merged_weight_names = set() tokenizer = None - for adapter_id in params.adapter_ids: - if adapter_id == BASE_MODEL_ADAPTER_ID: + for adapter in params.adapter_info: + if adapter.id == BASE_MODEL_ADAPTER_ID: raise ValueError("Base model adapter cannot be merged.") module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( load_module_map( model_id, - adapter_id, - adapter_params.adapter_source, + 
adapter.id, + adapter.path, weight_names, - api_token, trust_remote_code, ) ) @@ -159,25 +172,28 @@ def check_architectures( def load_module_map( model_id: str, adapter_id: str, - adapter_source: str, + adapter_path: Optional[str], weight_names: Tuple[str], - api_token: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: revision = "main" - adapter_config = LoraConfig.load(adapter_id, api_token) - if adapter_config.base_model_name_or_path != model_id: + adapter_config = LoraConfig.load(adapter_path or adapter_id, None) + + if not adapter_path and adapter_config.base_model_name_or_path != model_id: check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) - adapter_filenames = hub._cached_adapter_weight_files( - adapter_id, revision=revision, extension=".safetensors" + adapter_filenames = ( + hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors") + if adapter_path + else hub._cached_adapter_weight_files( + adapter_id, revision=revision, extension=".safetensors" + ) ) try: adapter_tokenizer = AutoTokenizer.from_pretrained( adapter_config.config_path, - token=api_token, trust_remote_code=trust_remote_code, ) except Exception: @@ -194,3 +210,87 @@ def load_module_map( adapter_weights, weight_names ) return module_map, adapter_config, adapter_weight_names, adapter_tokenizer + + +def get_attn_weights(i, layer): + qkv = layer.self_attn.query_key_value + weights = {} + + for k in ["q", "k", "v"]: + key = (i, f"{k}_proj") + value = (f"model.layers.{i}.self_attn.{k}_proj", qkv) + weights[key] = value + + weights[(i, "o_proj")] = ( + f"model.layers.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + return weights + + +def get_mlp_weights(i, layer): + weights = {} + if hasattr(layer, "mlp"): + mlp = layer.mlp + if hasattr(mlp, "gate_up_proj"): + # handle combined gate_up_proj (e.g., for some LLaMA variants) + weights.update( + { + (i, "gate_proj"): ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_up_proj, + ), + (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj), + } + ) + else: + # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma) + if hasattr(mlp, "gate_proj"): + weights[(i, "gate_proj")] = ( + f"model.layers.{i}.mlp.gate_proj", + mlp.gate_proj, + ) + if hasattr(mlp, "up_proj"): + weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj) + + if hasattr(mlp, "down_proj"): + weights[(i, "down_proj")] = ( + f"model.layers.{i}.mlp.down_proj", + mlp.down_proj, + ) + + return weights + + +# build_layer_weight_lookup creates a mapping of model layers to their corresponding +# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples +# containing the weight tensor path and the actual layer object. This mapping is needed +# for the lora adapter to know which weights to update when applying the adapter. 
+def build_layer_weight_lookup(model): + if hasattr(model, "language_model"): + m = model.language_model.model + elif hasattr(model, "text_model"): + m = model.text_model.model + else: + m = model.model + + layer_weights = {} + + for i, layer in enumerate(m.layers): + attn_weights = get_attn_weights(i, layer) + mlp_weights = get_mlp_weights(i, layer) + + layer_weights.update(attn_weights) + layer_weights.update(mlp_weights) + + lm_head = None + if hasattr(m, "lm_head"): + lm_head = m.lm_head + elif hasattr(model, "lm_head"): + lm_head = model.lm_head + + if lm_head: + layer_weights[(0, "lm_head")] = ("lm_head", lm_head) + + return layer_weights diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 36d63e86d..82aeba6ce 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -56,7 +56,7 @@ def initialize_torch_distributed(): backend = "nccl" options = ProcessGroupNCCL.Options() options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) + options._timeout = timedelta(seconds=120) else: backend = "gloo" options = None @@ -76,7 +76,7 @@ def initialize_torch_distributed(): backend="ccl", world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: @@ -84,7 +84,7 @@ def initialize_torch_distributed(): backend=backend, world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 011e0f635..7c0530143 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,15 +1,13 @@ import torch from loguru import logger -import subprocess import os +import importlib.util + + def is_ipex_available(): - try: - import intel_extension_for_pytorch - except ImportError: - return False - return True + return importlib.util.find_spec("intel_extension_for_pytorch") is not None def get_cuda_free_memory(device, memory_fraction): diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py index b1456f1e5..4385c71ee 100644 --- a/server/text_generation_server/utils/log.py +++ b/server/text_generation_server/utils/log.py @@ -1,6 +1,15 @@ from functools import lru_cache +from text_generation_server.utils.dist import RANK @lru_cache(10) -def log_once(log, msg: str): - log(msg) +def log_once(log, msg: str, master=True): + if master: + log_master(log, msg) + else: + log(msg) + + +def log_master(log, msg: str): + if RANK == 0: + log(msg) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 6b9154370..9abd886f2 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -521,7 +521,13 @@ class GrammarLogitProcessor(LogitsProcessor): def _cached_compile_fsm(grammar_type, schema, tokenizer): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: - schema = build_regex_from_schema(schema) + try: + schema = build_regex_from_schema(schema) + # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer + except Exception as e: + logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}") + # allows everything + 
schema = "(.*?)" elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: pass # schema is already a regex just here for clarity fsm = RegexFSM(schema, tokenizer) diff --git a/server/text_generation_server/utils/merges/strategies.py b/server/text_generation_server/utils/merges/strategies.py index 3b8853136..cb39cde1f 100644 --- a/server/text_generation_server/utils/merges/strategies.py +++ b/server/text_generation_server/utils/merges/strategies.py @@ -2,9 +2,17 @@ import copy from abc import ABC from collections import defaultdict from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union - +from text_generation_server.utils.merges.utils import ( + calculate_majority_sign_mask, + disjoint_merge, + prune, +) import torch +if TYPE_CHECKING: + from text_generation_server.adapters.lora import LoraConfig + from text_generation_server.utils.adapter import ModuleMap + class AdapterParameters: def __init__( @@ -17,17 +25,6 @@ class AdapterParameters: self.majority_sign_method = majority_sign_method -from text_generation_server.utils.merges.utils import ( - calculate_majority_sign_mask, - disjoint_merge, - prune, -) - -if TYPE_CHECKING: - from text_generation_server.adapters.lora import LoraConfig - from text_generation_server.utils.adapter import ModuleMap - - def _apply_weights( tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor ) -> torch.Tensor: diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index 0ea89267b..d49e73f00 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -28,7 +28,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): low_cpu_mem_usage=True, ) logger.info("Peft model detected.") - logger.info(f"Merging the lora weights.") + logger.info("Merging the lora weights.") base_model_id = model.peft_config["default"].base_model_name_or_path diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index 07975bea0..ee561acc4 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -1,13 +1,17 @@ -from typing import Optional -import os import json +import os from dataclasses import dataclass +from typing import Optional from huggingface_hub import hf_hub_download - -from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader +from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + WeightsLoader, +) +# TODO: Split this config to have a single config type per quant method @dataclass class _QuantizerConfig: bits: int @@ -18,6 +22,11 @@ class _QuantizerConfig: sym: bool +@dataclass +class _FP8QuantizerConfig: + activation_scale_ub: float + + # We should probably do this with Pytantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. 
def _get_quantizer_config(model_id, revision): @@ -25,7 +34,7 @@ def _get_quantizer_config(model_id, revision): groupsize = -1 quant_method = "gptq" checkpoint_format = None - sym = True + sym = False desc_act = False filename = "config.json" @@ -36,12 +45,24 @@ def _get_quantizer_config(model_id, revision): filename = hf_hub_download(model_id, filename=filename, revision=revision) with open(filename, "r") as f: data = json.load(f) + + # FP8 config + if data["quantization_config"]["quant_method"] == "fbgemm_fp8": + return _FP8QuantizerConfig( + activation_scale_ub=data["quantization_config"]["activation_scale_ub"] + ) + + if "zero_point" in data["quantization_config"]: + sym = not data["quantization_config"]["zero_point"] + quant_method = "awq" + elif "sym" in data["quantization_config"]: + sym = data["quantization_config"]["sym"] + bits = data["quantization_config"]["bits"] groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models quant_method = data["quantization_config"]["quant_method"] checkpoint_format = data["quantization_config"].get("checkpoint_format") - sym = data["quantization_config"]["sym"] desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" @@ -56,7 +77,13 @@ def _get_quantizer_config(model_id, revision): data = json.load(f) bits = data["bits"] groupsize = data["group_size"] - sym = data["sym"] + + if "zero_point" in data: + sym = not data["zero_point"] + quant_method = "awq" + elif "sym" in data: + sym = data["sym"] + desc_act = data["desc_act"] if "version" in data and data["version"] == "GEMM": quant_method = "awq" @@ -96,14 +123,54 @@ def get_loader( if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader - return GPTQWeightsLoader( + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
+ ) + + if can_use_gptq_marlin( bits=quantizer_config.bits, - desc_act=quantizer_config.desc_act, groupsize=quantizer_config.groupsize, quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, - ) + ): + from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader + + return GPTQMarlinWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + else: + return GPTQWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + elif quantize == "bitsandbytes": + from text_generation_server.layers.bnb import BNBWeight + + return DefaultWeightsLoader(BNBWeight) + elif quantize == "bitsandbytes-fp4": + from text_generation_server.layers.bnb import BNBFP4Weight + + return DefaultWeightsLoader(BNBFP4Weight) + elif quantize == "bitsandbytes-nf4": + from text_generation_server.layers.bnb import BNBNF4Weight + + return DefaultWeightsLoader(BNBNF4Weight) + elif quantize == "eetq": + from text_generation_server.layers.eetq import EETQWeight + + return DefaultWeightsLoader(EETQWeight) elif quantize == "exl2": from text_generation_server.layers.exl2 import Exl2WeightsLoader @@ -111,9 +178,25 @@ def get_loader( elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
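A note on the defensive `isinstance` checks being added to `get_loader` above: since `_get_quantizer_config` can now return either a GPTQ/AWQ-style config or an FP8 config, each loader branch verifies it received the config type it expects before reading fields such as `bits` or `groupsize`. The sketch below shows only that guard idea, with made-up stand-in dataclasses rather than the real config types.

```python
# Toy sketch of the config-type guard; stand-in dataclasses, not TGI's own.
from dataclasses import dataclass


@dataclass
class GPTQStyleConfig:  # stand-in for a GPTQ/AWQ-style quantizer config
    bits: int
    groupsize: int


@dataclass
class FP8StyleConfig:  # stand-in for an FP8 quantizer config
    activation_scale_ub: float


def require_gptq_config(quantize: str, cfg):
    # Refuse to build a GPTQ/AWQ/Marlin loader from a config that was
    # actually parsed as FP8 (or anything else unexpected).
    if not isinstance(cfg, GPTQStyleConfig):
        raise ValueError(
            f"Quantize is set to `{quantize}` but received a "
            f"`{cfg.__class__.__name__}` config."
        )
    return cfg


print(require_gptq_config("gptq", GPTQStyleConfig(bits=4, groupsize=128)))
```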
+ ) + return MarlinWeightsLoader( bits=quantizer_config.bits, is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) + elif quantize == "fp8" or quantize is None: + from text_generation_server.layers.fp8 import HybridFP8UnquantLoader + + # Since the default for the quantize config is _QuantizerConfig, + # we need to add this check to not get an attribute error + activation_scale_ub = None + if isinstance(quantizer_config, _FP8QuantizerConfig): + activation_scale_ub = quantizer_config.activation_scale_ub + + return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") else: - return DefaultWeightsLoader() + raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 22f86b60c..9ab49665a 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,7 +1,6 @@ import re from typing import List, Optional, Tuple, Set, Union -import math import torch from text_generation_server.pb import generate_pb2 from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 1a62fb3bd..108ced480 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,9 +1,14 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Dict, List, Optional, Union -from safetensors import safe_open import torch +from abc import ABC, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, List, Optional, Union, Type +from safetensors import safe_open +from dataclasses import dataclass + +from text_generation_server.utils.import_utils import SYSTEM + class WeightsLoader(ABC): """ @@ -16,6 +21,13 @@ class WeightsLoader(ABC): with the format, etc. """ + @abstractmethod + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + ... + @abstractmethod def get_weights_col_packed( self, @@ -61,28 +73,68 @@ class WeightsLoader(ABC): ... +class Weight(ABC): + """Instances of this type implement unquantized/quantized/to-be + quantized weights.""" + + @abstractmethod + def get_linear(self, bias: torch.Tensor): + """Create a linear layer from this weight.""" + ... + + +@dataclass +class UnquantizedWeight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.linear import FastLinear, FastLinearROCm + + if SYSTEM == "rocm": + return FastLinearROCm(self.weight, bias) + else: + return FastLinear(self.weight, bias) + + class DefaultWeightsLoader(WeightsLoader): + """Weight loader that loads (unquantized) Torch tensors.""" + + def __init__(self, weight_class: Type[UnquantizedWeight]): + """Create a loader. Weights will be wrapped using the given `weights_class`, + normally this will be `UnquantizedWeight`, but a quantizer-specific class + such as `Fp8Weight` can be used to quantize the weights during loading. + """ + self.weight_class = weight_class + """ Loader that uses tensors as-is with the exception of applying sharding and/or concatenation. 
""" + def get_weights(self, weights: "Weights", prefix: str): + return weights.get_tensor(f"{prefix}.weight") + def get_weights_col_packed( self, weights: "Weights", prefix: str, block_sizes: Union[int, List[int]], ): - return weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes + + return self.weight_class( + weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ), ) def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - return torch.cat(w, dim=dim) + return self.weight_class(torch.cat(w, dim=dim)) def get_weights_row(self, weights: "Weights", prefix: str): - return weights.get_sharded(f"{prefix}.weight", dim=1) + return self.weight_class( + weights.get_sharded(f"{prefix}.weight", dim=1), + ) class Weights: @@ -146,23 +198,41 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True): + def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert # u4 which are disguised as int32. Exl2 uses int16 - # as well. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + # as well. FP8 uses torch.float8_e4m3fn + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded( + self, tensor_name: str, dim: int, to_device=True, to_dtype=True + ): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -182,12 +252,17 @@ class Weights: raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert # u4 which are disguised as int32. exl2 uses int16. - if tensor.dtype not in (torch.int16, torch.int32): + # FP8 uses torch.float8_e4m3fn. + if ( + tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) - tensor = tensor.to(device=self.device) + if to_device: + tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -196,10 +271,16 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim) + return self.get_partial_sharded( + tensor_name, dim, to_device=to_device, to_dtype=to_dtype + ) def get_packed_sharded( - self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + self, + tensor_name: str, + dim: int, + block_sizes: Union[int, List[int]], + to_dtype=True, ) -> torch.Tensor: """ Get a shard from a tensor that packs multiple tensors. 
@@ -245,15 +326,26 @@ class Weights: tensor = tensor.to(device=self.device) # Avoid casting quantizer dtypes. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) return tensor + def get_weights(self, prefix: str): + return self.weights_loader.get_weights(self, prefix) + def get_weights_col_packed_qkv( self, prefix: str, - quantize: str, num_heads: int, num_key_value_heads: int, ): @@ -299,6 +391,20 @@ class Weights: def get_weights_row(self, prefix: str): return self.weights_loader.get_weights_row(self, prefix) + @contextmanager + def use_loader(self, weights_loader: WeightsLoader): + """ + This method is a context manager that can be used to use `Weights` with + a different loader for the duration of the context. + """ + + old_loader = self.weights_loader + self.weights_loader = weights_loader + try: + yield + finally: + self.weights_loader = old_loader + def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: """ diff --git a/update_doc.py b/update_doc.py index bfa7e4e9d..428d44521 100644 --- a/update_doc.py +++ b/update_doc.py @@ -167,22 +167,24 @@ def check_openapi(check: bool): else: os.rename(tmp_filename, filename) print("OpenAPI documentation updated.") - errors = subprocess.run( + p = subprocess.run( [ - "swagger-cli", + "redocly", # allow for trailing whitespace since it's not significant # and the precommit hook will remove it - "validate", + "lint", filename, ], capture_output=True, - ).stderr.decode("utf-8") + ) + errors = p.stderr.decode("utf-8") # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 - if not errors.startswith("Swagger schema validation failed."): + print(errors) + if p.returncode != 0: print(errors) raise Exception( - f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" + f"OpenAPI documentation is invalid, `redocly lint {filename}` showed some error:\n {errors}" ) return True
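Closing aside: the `use_loader` context manager added to `Weights` in this diff is a small but handy pattern (temporarily swap the loader, always restore it, even on error). Below is a dependency-free sketch with a toy class, just to show the try/finally shape; it is not the real `Weights` implementation.

```python
# Standalone illustration of the use_loader context-manager pattern.
from contextlib import contextmanager


class ToyWeights:  # stand-in for text_generation_server.utils.weights.Weights
    def __init__(self, weights_loader):
        self.weights_loader = weights_loader

    @contextmanager
    def use_loader(self, weights_loader):
        # Swap the loader for the duration of the `with` block and restore it
        # even if the block raises.
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            self.weights_loader = old_loader


w = ToyWeights("default-loader")
with w.use_loader("marlin-loader"):
    print(w.weights_loader)  # marlin-loader
print(w.weights_loader)      # default-loader
```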