Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

Commit f0a5cb6c4e: Merge branch 'main' into add-mistral-nemo
.github/workflows/build.yaml (vendored, 27 lines changed)

@@ -27,8 +27,8 @@ jobs:
 concurrency:
 group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
 cancel-in-progress: true
-# TODO see with @Glegendre to get CPU runner here instead
-runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
+runs-on:
+group: aws-highmemory-32-plus-priv
 permissions:
 contents: write
 packages: write
@@ -49,7 +49,7 @@ jobs:
 export dockerfile="Dockerfile"
 export label_extension=""
 export docker_devices=""
-export runs_on="nvidia-gpu"
+export runs_on="aws-g5-12xlarge-plus"
 ;;
 rocm)
 export dockerfile="Dockerfile_amd"
@@ -79,9 +79,15 @@ jobs:
 uses: docker/setup-buildx-action@v3
 with:
 install: true
-config-inline: |
+buildkitd-config-inline: |
 [registry."docker.io"]
-mirrors = ["registry.github-runners.huggingface.tech"]
+mirrors = ["registry-us-east-1-mirror.prod.aws.ci.huggingface.tech"]
+- name: Login to internal Container Registry
+uses: docker/login-action@v3
+with:
+username: ${{ secrets.REGISTRY_USERNAME }}
+password: ${{ secrets.REGISTRY_PASSWORD }}
+registry: registry.internal.huggingface.tech
 - name: Login to GitHub Container Registry
 if: github.event_name != 'pull_request'
 uses: docker/login-action@v3
@@ -103,7 +109,8 @@ jobs:
 uses: docker/metadata-action@v5
 with:
 images: |
-registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
+registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference
+registry.internal.huggingface.tech/api-inference/community/text-generation-inference
 tags: |
 type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
 # If main, release or tag
@@ -115,7 +122,8 @@ jobs:
 flavor: |
 latest=auto
 images: |
-registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
+registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference
+registry.internal.huggingface.tech/api-inference/community/text-generation-inferenceca
 ghcr.io/huggingface/text-generation-inference
 db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
 tags: |
@@ -141,7 +149,7 @@ jobs:
 - name: Final
 id: final
 run: |
-echo "docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+echo "docker_image=registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
 echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
 echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
 echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
@@ -150,7 +158,8 @@ jobs:
 group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
 cancel-in-progress: true
 needs: build-and-push
-runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
+runs-on:
+group: ${{ needs.build-and-push.outputs.runs_on }}
 if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
 env:
 PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
.github/workflows/load_test.yaml (vendored, 3 lines changed)

@@ -15,7 +15,8 @@ jobs:
 concurrency:
 group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
 cancel-in-progress: true
-runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
+runs-on:
+group: aws-g5-12xlarge
 env:
 DOCKER_VOLUME: /cache
 steps:
Cargo.lock (generated, 53 lines changed)

@@ -801,6 +801,27 @@ dependencies = [
 "typenum",
 ]
 
+[[package]]
+name = "csv"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe"
+dependencies = [
+"csv-core",
+"itoa",
+"ryu",
+"serde",
+]
+
+[[package]]
+name = "csv-core"
+version = "0.1.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70"
+dependencies = [
+"memchr",
+]
+
 [[package]]
 name = "ctrlc"
 version = "3.4.4"
@@ -3402,9 +3423,9 @@ dependencies = [
 
 [[package]]
 name = "serde_json"
-version = "1.0.118"
+version = "1.0.120"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4"
+checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5"
 dependencies = [
 "itoa",
 "ryu",
@@ -3650,15 +3671,16 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
 
 [[package]]
 name = "sysinfo"
-version = "0.30.12"
+version = "0.30.13"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae"
+checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3"
 dependencies = [
 "cfg-if",
 "core-foundation-sys",
 "libc",
 "ntapi",
 "once_cell",
+"rayon",
 "windows",
 ]
 
@@ -3805,6 +3827,7 @@ dependencies = [
 "axum-tracing-opentelemetry",
 "base64 0.22.1",
 "clap",
+"csv",
 "futures",
 "futures-util",
 "hf-hub",
@@ -3826,6 +3849,7 @@ dependencies = [
 "reqwest",
 "serde",
 "serde_json",
+"sysinfo",
 "text-generation-client",
 "thiserror",
 "tokenizers",
@@ -3837,6 +3861,7 @@ dependencies = [
 "tracing-subscriber",
 "utoipa",
 "utoipa-swagger-ui",
+"uuid",
 "vergen",
 ]
 
@@ -4508,9 +4533,25 @@ dependencies = [
 
 [[package]]
 name = "uuid"
-version = "1.9.1"
+version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439"
+checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314"
+dependencies = [
+"getrandom",
+"rand",
+"uuid-macro-internal",
+]
+
+[[package]]
+name = "uuid-macro-internal"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a"
+dependencies = [
+"proc-macro2",
+"quote",
+"syn 2.0.68",
+]
 
 [[package]]
 name = "v_frame"
Dockerfile (17 lines changed)

@@ -161,6 +161,17 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build
 
+# Build FBGEMM CUDA kernels
+FROM kernel-builder AS fbgemm-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-fbgemm Makefile
+COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
+COPY server/fix_torch90a.sh fix_torch90a.sh
+
+RUN make build-fbgemm
+
 # Build vllm CUDA kernels
 FROM kernel-builder AS vllm-builder
 
@@ -225,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
 # Copy build artifacts from marlin kernels builder
 COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# Copy builds artifacts from vllm builder
+# Copy build artifacts from fbgemm builder
+COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
@@ -1,3 +1,6 @@
+# Legacy warning ⚠️
+The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`.
+
 # Text Generation
 
 The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
@@ -909,7 +909,7 @@
 "tool_choice": {
 "allOf": [
 {
-"$ref": "#/components/schemas/ToolType"
+"$ref": "#/components/schemas/ToolChoice"
 }
 ],
 "nullable": true
@@ -2035,6 +2035,14 @@
 }
 }
 },
+"ToolChoice": {
+"allOf": [
+{
+"$ref": "#/components/schemas/ToolType"
+}
+],
+"nullable": true
+},
 "ToolType": {
 "oneOf": [
 {
@@ -2055,6 +2063,11 @@
 "$ref": "#/components/schemas/FunctionName"
 }
 }
+},
+{
+"type": "object",
+"default": null,
+"nullable": true
 }
 ]
 },
@@ -21,6 +21,8 @@
 title: Messages API
 - local: architecture
 title: Internal Architecture
+- local: usage_statistics
+title: Usage Statistics
 title: Getting started
 - sections:
 - local: basic_tutorials/consuming_tgi
@@ -424,6 +424,22 @@ Options:
 
 [env: LORA_ADAPTERS=]
 
+```
+## DISABLE_USAGE_STATS
+```shell
+--disable-usage-stats
+Disable sending of all usage statistics
+
+[env: DISABLE_USAGE_STATS=]
+
+```
+## DISABLE_CRASH_REPORTS
+```shell
+--disable-crash-reports
+Disable sending of crash reports, but allow anonymous usage statistics
+
+[env: DISABLE_CRASH_REPORTS=]
+
 ```
 ## HELP
 ```shell
@@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 
 ## Supported Models
 
+- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
docs/source/usage_statistics.md (new file, 73 lines)
@@ -0,0 +1,73 @@

# Collection of Usage Statistics

Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted.

Data is sent twice, once on server startup and once when server stops. Also, usage statistics are only enabled when TGI is running in docker to avoid collecting data then TGI runs directly on the host machine.

## What data is collected

The code that collects the data is available [here](https://github.com/huggingface/text-generation-inference/blob/main/router/src/usage_stats.rs).
As of release 2.1.2 this is an example of the data collected:

- From the TGI configuration:
```json
{
  "event_type": "start",
  "disable_grammar_support": false,
  "max_batch_prefill_tokens": 4096,
  "max_batch_size": null,
  "max_batch_total_tokens": null,
  "max_best_of": 2,
  "max_client_batch_size": 4,
  "max_concurrent_requests": 128,
  "max_input_tokens": 1024,
  "max_stop_sequences": 4,
  "max_top_n_tokens": 5,
  "max_total_tokens": 2048,
  "max_waiting_tokens": 20,
  "messages_api_enabled": false,
  "model_config": {
    "model_type": "Bloom"
  },
  "revision": null,
  "tokenizer_class": "BloomTokenizerFast",
  "validation_workers": 2,
  "waiting_served_ratio": 1.2,
  "docker_label": "latest",
  "git_sha": "cfc118704880453d29bcbe4fbbd91dda501cf5fe",
  "nvidia_env": {
    "name": "NVIDIA A10G",
    "pci_bus_id": "00000000:00:1E.0",
    "driver_version": "535.183.01",
    "pstate": "P8",
    "pcie_link_gen_max": "4",
    "pcie_link_gen_current": "1",
    "temperature_gpu": "31",
    "utilization_gpu": "0 %",
    "utilization_memory": "0 %",
    "memory_total": "23028 MiB",
    "memory_free": "22515 MiB",
    "memory_used": "0 MiB",
    "reset_status_reset_required": "No",
    "reset_status_drain_and_reset_recommended": "No",
    "compute_cap": "8.6",
    "ecc_errors_corrected_volatile_total": "0",
    "mig_mode_current": "[N/A]",
    "power_draw_instant": "10.86 W",
    "power_limit": "300.00 W"
  },
  "system_env": {
    "cpu_count": 16,
    "cpu_type": "AMD EPYC 7R32",
    "total_memory": 66681196544,
    "architecture": "x86_64",
    "platform": "linux-unix-x86_64"
  }
}
```

## How to opt-out

You can easily opt out by passing the `--disable-usage-stats` to the text-generation-launcher command. This will disable all usage statistics. You can also pass `--disable-crash-reports` which disables sending specific crash reports, but allows anonymous usage statistics.
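As a minimal illustrative sketch (not taken from this commit), assuming the launcher is invoked directly and with the model id left as a placeholder, opting out of both collection paths would look roughly like:

    text-generation-launcher --model-id <model> --disable-usage-stats --disable-crash-reports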
@@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 100000, "logprob": null, "text": "<|begin▁of▁sentence|>"},
      {"id": 3533, "logprob": -9.625, "text": "Test"},
      {"id": 3102, "logprob": -11.1875, "text": " request"}
    ],
    "seed": null,
    "tokens": [
      {"id": 185, "logprob": -1.5546875, "special": false, "text": "\n"},
      {"id": 549, "logprob": -2.84375, "special": false, "text": "The"},
      {"id": 1727, "logprob": -2.34375, "special": false, "text": " test"},
      {"id": 3102, "logprob": -0.8359375, "special": false, "text": " request"},
      {"id": 317, "logprob": -1.0859375, "special": false, "text": " is"},
      {"id": 254, "logprob": -1.5390625, "special": false, "text": " the"},
      {"id": 1022, "logprob": -1.1875, "special": false, "text": " first"},
      {"id": 3458, "logprob": -0.35546875, "special": false, "text": " step"},
      {"id": 279, "logprob": -0.8828125, "special": false, "text": " in"},
      {"id": 254, "logprob": -0.71484375, "special": false, "text": " the"}
    ],
    "top_tokens": null
  },
  "generated_text": "\nThe test request is the first step in the"
}
@@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 100000, "logprob": null, "text": "<|begin▁of▁sentence|>"},
      {"id": 3533, "logprob": -9.625, "text": "Test"},
      {"id": 3102, "logprob": -11.1875, "text": " request"}
    ],
    "seed": 0,
    "tokens": [
      {"id": 2143, "logprob": -1.828125, "special": false, "text": " sent"},
      {"id": 10081, "logprob": -0.36914062, "special": false, "text": " successfully"},
      {"id": 13, "logprob": 0.0, "special": false, "text": "."},
      {"id": 185, "logprob": 0.0, "special": false, "text": "\n"},
      {"id": 1380, "logprob": -0.38671875, "special": false, "text": "We"},
      {"id": 543, "logprob": -0.12695312, "special": false, "text": " will"},
      {"id": 752, "logprob": -0.20117188, "special": false, "text": " get"},
      {"id": 279, "logprob": 0.0, "special": false, "text": " in"},
      {"id": 5402, "logprob": 0.0, "special": false, "text": " touch"},
      {"id": 366, "logprob": 0.0, "special": false, "text": " with"}
    ],
    "top_tokens": null
  },
  "generated_text": "Test request sent successfully.\nWe will get in touch with"
}
@@ -0,0 +1,358 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 100000, "logprob": null, "text": "<|begin▁of▁sentence|>"},
        {"id": 3533, "logprob": -9.625, "text": "Test"},
        {"id": 3102, "logprob": -11.1875, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 185, "logprob": -1.5546875, "special": false, "text": "\n"},
        {"id": 549, "logprob": -2.8125, "special": false, "text": "The"},
        {"id": 1727, "logprob": -2.375, "special": false, "text": " test"},
        {"id": 3102, "logprob": -0.890625, "special": false, "text": " request"},
        {"id": 317, "logprob": -1.1484375, "special": false, "text": " is"},
        {"id": 245, "logprob": -1.5390625, "special": false, "text": " a"},
        {"id": 3102, "logprob": -2.609375, "special": false, "text": " request"},
        {"id": 327, "logprob": -0.75, "special": false, "text": " for"},
        {"id": 245, "logprob": -1.1171875, "special": false, "text": " a"},
        {"id": 1727, "logprob": -0.90625, "special": false, "text": " test"}
      ],
      "top_tokens": null
    },
    "generated_text": "\nThe test request is a request for a test"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 100000, "logprob": null, "text": "<|begin▁of▁sentence|>"},
        {"id": 3533, "logprob": -9.625, "text": "Test"},
        {"id": 3102, "logprob": -11.25, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 185, "logprob": -1.5546875, "special": false, "text": "\n"},
        {"id": 549, "logprob": -2.8125, "special": false, "text": "The"},
        {"id": 1727, "logprob": -2.375, "special": false, "text": " test"},
        {"id": 3102, "logprob": -0.890625, "special": false, "text": " request"},
        {"id": 317, "logprob": -1.1484375, "special": false, "text": " is"},
        {"id": 245, "logprob": -1.5390625, "special": false, "text": " a"},
        {"id": 3102, "logprob": -2.609375, "special": false, "text": " request"},
        {"id": 327, "logprob": -0.75, "special": false, "text": " for"},
        {"id": 245, "logprob": -1.1171875, "special": false, "text": " a"},
        {"id": 1727, "logprob": -0.90625, "special": false, "text": " test"}
      ],
      "top_tokens": null
    },
    "generated_text": "\nThe test request is a request for a test"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 100000, "logprob": null, "text": "<|begin▁of▁sentence|>"},
        {"id": 3533, "logprob": -9.625, "text": "Test"},
        {"id": 3102, "logprob": -11.25, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 185, "logprob": -1.5546875, "special": false, "text": "\n"},
        {"id": 549, "logprob": -2.8125, "special": false, "text": "The"},
        {"id": 1727, "logprob": -2.375, "special": false, "text": " test"},
        {"id": 3102, "logprob": -0.890625, "special": false, "text": " request"},
        {"id": 317, "logprob": -1.1484375, "special": false, "text": " is"},
        {"id": 245, "logprob": -1.5390625, "special": false, "text": " a"},
        {"id": 3102, "logprob": -2.609375, "special": false, "text": " request"},
        {"id": 327, "logprob": -0.75, "special": false, "text": " for"},
        {"id": 245, "logprob": -1.1171875, "special": false, "text": " a"},
        {"id": 1727, "logprob": -0.90625, "special": false, "text": " test"}
      ],
      "top_tokens": null
    },
    "generated_text": "\nThe test request is a request for a test"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 100000, "logprob": null, "text": "<|begin▁of▁sentence|>"},
        {"id": 3533, "logprob": -9.625, "text": "Test"},
        {"id": 3102, "logprob": -11.25, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 185, "logprob": -1.5546875, "special": false, "text": "\n"},
        {"id": 549, "logprob": -2.8125, "special": false, "text": "The"},
        {"id": 1727, "logprob": -2.375, "special": false, "text": " test"},
        {"id": 3102, "logprob": -0.890625, "special": false, "text": " request"},
        {"id": 317, "logprob": -1.1484375, "special": false, "text": " is"},
        {"id": 245, "logprob": -1.5390625, "special": false, "text": " a"},
        {"id": 3102, "logprob": -2.609375, "special": false, "text": " request"},
        {"id": 327, "logprob": -0.75, "special": false, "text": " for"},
        {"id": 245, "logprob": -1.1171875, "special": false, "text": " a"},
        {"id": 1727, "logprob": -0.90625, "special": false, "text": " test"}
      ],
      "top_tokens": null
    },
    "generated_text": "\nThe test request is a request for a test"
  }
]
@@ -0,0 +1,254 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 2, "logprob": null, "text": "<bos>"},
      {"id": 106, "logprob": -47.25, "text": "<start_of_turn>"},
      {"id": 1645, "logprob": -18.875, "text": "user"},
      {"id": 235292, "logprob": -7.15625, "text": ":"},
      {"id": 108, "logprob": -4.78125, "text": "\n"},
      {"id": 5559, "logprob": -10.0, "text": "Write"},
      {"id": 476, "logprob": -0.1171875, "text": " a"},
      {"id": 19592, "logprob": -2.46875, "text": " poem"},
      {"id": 577, "logprob": -5.84375, "text": " to"},
      {"id": 1707, "logprob": -6.375, "text": " help"},
      {"id": 682, "logprob": -2.125, "text": " me"},
      {"id": 5434, "logprob": -1.546875, "text": " remember"},
      {"id": 573, "logprob": -0.62890625, "text": " the"},
      {"id": 1370, "logprob": -6.65625, "text": " first"},
      {"id": 235248, "logprob": -1.84375, "text": " "},
      {"id": 235274, "logprob": -0.45117188, "text": "1"},
      {"id": 235276, "logprob": -0.07421875, "text": "0"},
      {"id": 6635, "logprob": -2.109375, "text": " elements"},
      {"id": 611, "logprob": -0.4140625, "text": " on"},
      {"id": 573, "logprob": -0.0009536743, "text": " the"},
      {"id": 26163, "logprob": -0.033203125, "text": " periodic"},
      {"id": 3037, "logprob": -0.0002670288, "text": " table"},
      {"id": 235269, "logprob": -4.75, "text": ","},
      {"id": 7385, "logprob": -11.625, "text": " giving"},
      {"id": 1853, "logprob": -4.875, "text": " each"},
      {"id": 5356, "logprob": -0.38867188, "text": " element"},
      {"id": 1277, "logprob": -3.65625, "text": " its"},
      {"id": 1997, "logprob": -4.4375, "text": " own"},
      {"id": 2017, "logprob": -0.29882812, "text": " line"},
      {"id": 235265, "logprob": -0.16699219, "text": "."},
      {"id": 107, "logprob": -25.625, "text": "<end_of_turn>"},
      {"id": 108, "logprob": -6.75, "text": "\n"},
      {"id": 106, "logprob": -39.5, "text": "<start_of_turn>"},
      {"id": 2516, "logprob": -32.5, "text": "model"},
      {"id": 235292, "logprob": -10.125, "text": ":"},
      {"id": 108, "logprob": -3.421875, "text": "\n"}
    ],
    "seed": null,
    "tokens": [
      {"id": 688, "logprob": -0.546875, "special": false, "text": "**"},
      {"id": 103889, "logprob": -0.49023438, "special": false, "text": "Hydrogen"},
      {"id": 190213, "logprob": -0.48632812, "special": false, "text": "**,"},
      {"id": 2611, "logprob": -0.58203125, "special": false, "text": " light"},
      {"id": 578, "logprob": -0.099121094, "special": false, "text": " and"},
      {"id": 2223, "logprob": -1.078125, "special": false, "text": " free"},
      {"id": 235269, "logprob": -0.025756836, "special": false, "text": ","},
      {"id": 108, "logprob": -0.29101562, "special": false, "text": "\n"},
      {"id": 688, "logprob": -0.0035858154, "special": false, "text": "**"},
      {"id": 1949, "logprob": -4.1007996e-05, "special": false, "text": "He"}
    ],
    "top_tokens": null
  },
  "generated_text": "**Hydrogen**, light and free,\n**He"
}
(One file diff suppressed because it is too large.)
@@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
      {"id": 2323, "logprob": -9.421875, "text": "Test"},
      {"id": 1715, "logprob": -10.546875, "text": " request"}
    ],
    "seed": null,
    "tokens": [
      {"id": 369, "logprob": -2.1816406, "special": false, "text": " for"},
      {"id": 279, "logprob": -2.6992188, "special": false, "text": " the"},
      {"id": 220, "logprob": -3.6308594, "special": false, "text": " "},
      {"id": 679, "logprob": -1.7900391, "special": false, "text": "201"},
      {"id": 24, "logprob": -1.3554688, "special": false, "text": "9"},
      {"id": 12, "logprob": -2.0039062, "special": false, "text": "-"},
      {"id": 2366, "logprob": -0.4489746, "special": false, "text": "202"},
      {"id": 15, "logprob": -0.037109375, "special": false, "text": "0"},
      {"id": 2978, "logprob": -0.8100586, "special": false, "text": " school"},
      {"id": 1060, "logprob": -0.013015747, "special": false, "text": " year"}
    ],
    "top_tokens": null
  },
  "generated_text": " for the 2019-2020 school year"
}
@@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
      {"id": 2323, "logprob": -9.5625, "text": "Test"},
      {"id": 1715, "logprob": -10.375, "text": " request"}
    ],
    "seed": 0,
    "tokens": [
      {"id": 25, "logprob": -0.8984375, "special": false, "text": ":"},
      {"id": 2209, "logprob": -2.78125, "special": false, "text": " Is"},
      {"id": 279, "logprob": -0.6328125, "special": false, "text": " the"},
      {"id": 734, "logprob": -2.703125, "special": false, "text": " function"},
      {"id": 330, "logprob": -0.34179688, "special": false, "text": " \""},
      {"id": 4110, "logprob": -2.359375, "special": false, "text": "Create"},
      {"id": 7575, "logprob": -2.1875, "special": false, "text": "Process"},
      {"id": 1, "logprob": -0.07910156, "special": false, "text": "\""},
      {"id": 304, "logprob": -0.83203125, "special": false, "text": " in"},
      {"id": 12468, "logprob": -1.8203125, "special": false, "text": " Win"}
    ],
    "top_tokens": null
  },
  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
}
@@ -0,0 +1,358 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 2323, "logprob": -9.421875, "text": "Test"},
        {"id": 1715, "logprob": -10.546875, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 369, "logprob": -2.1816406, "special": false, "text": " for"},
        {"id": 279, "logprob": -2.6992188, "special": false, "text": " the"},
        {"id": 220, "logprob": -3.6308594, "special": false, "text": " "},
        {"id": 679, "logprob": -1.7988281, "special": false, "text": "201"},
        {"id": 24, "logprob": -1.3535156, "special": false, "text": "9"},
        {"id": 12, "logprob": -2.0058594, "special": false, "text": "-"},
        {"id": 2366, "logprob": -0.45410156, "special": false, "text": "202"},
        {"id": 15, "logprob": -0.037109375, "special": false, "text": "0"},
        {"id": 2978, "logprob": -0.8095703, "special": false, "text": " school"},
        {"id": 1060, "logprob": -0.013053894, "special": false, "text": " year"}
      ],
      "top_tokens": null
    },
    "generated_text": " for the 2019-2020 school year"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 2323, "logprob": -9.421875, "text": "Test"},
        {"id": 1715, "logprob": -10.546875, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 369, "logprob": -2.1816406, "special": false, "text": " for"},
        {"id": 279, "logprob": -2.6992188, "special": false, "text": " the"},
        {"id": 220, "logprob": -3.6308594, "special": false, "text": " "},
        {"id": 679, "logprob": -1.7988281, "special": false, "text": "201"},
        {"id": 24, "logprob": -1.3535156, "special": false, "text": "9"},
        {"id": 12, "logprob": -2.0058594, "special": false, "text": "-"},
        {"id": 2366, "logprob": -0.45410156, "special": false, "text": "202"},
        {"id": 15, "logprob": -0.037109375, "special": false, "text": "0"},
        {"id": 2978, "logprob": -0.8095703, "special": false, "text": " school"},
        {"id": 1060, "logprob": -0.013053894, "special": false, "text": " year"}
      ],
      "top_tokens": null
    },
    "generated_text": " for the 2019-2020 school year"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 2323, "logprob": -9.421875, "text": "Test"},
        {"id": 1715, "logprob": -10.546875, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 369, "logprob": -2.1816406, "special": false, "text": " for"},
        {"id": 279, "logprob": -2.6992188, "special": false, "text": " the"},
        {"id": 220, "logprob": -3.6308594, "special": false, "text": " "},
        {"id": 679, "logprob": -1.7988281, "special": false, "text": "201"},
        {"id": 24, "logprob": -1.3535156, "special": false, "text": "9"},
        {"id": 12, "logprob": -2.0058594, "special": false, "text": "-"},
        {"id": 2366, "logprob": -0.45410156, "special": false, "text": "202"},
        {"id": 15, "logprob": -0.037109375, "special": false, "text": "0"},
        {"id": 2978, "logprob": -0.8095703, "special": false, "text": " school"},
        {"id": 1060, "logprob": -0.013053894, "special": false, "text": " year"}
      ],
      "top_tokens": null
    },
    "generated_text": " for the 2019-2020 school year"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 2323, "logprob": -9.421875, "text": "Test"},
        {"id": 1715, "logprob": -10.546875, "text": " request"}
      ],
      "seed": null,
      "tokens": [
        {"id": 369, "logprob": -2.1816406, "special": false, "text": " for"},
        {"id": 279, "logprob": -2.6992188, "special": false, "text": " the"},
        {"id": 220, "logprob": -3.6308594, "special": false, "text": " "},
        {"id": 679, "logprob": -1.7988281, "special": false, "text": "201"},
        {"id": 24, "logprob": -1.3535156, "special": false, "text": "9"},
        {"id": 12, "logprob": -2.0058594, "special": false, "text": "-"},
        {"id": 2366, "logprob": -0.45410156, "special": false, "text": "202"},
        {"id": 15, "logprob": -0.037109375, "special": false, "text": "0"},
        {"id": 2978, "logprob": -0.8095703, "special": false, "text": " school"},
        {"id": 1060, "logprob": -0.013053894, "special": false, "text": " year"}
      ],
      "top_tokens": null
    },
    "generated_text": " for the 2019-2020 school year"
  }
]
integration-tests/models/test_flash_deepseek_v2.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import pytest


@pytest.fixture(scope="module")
def flash_deepseek_v2_handle(launcher):
    with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_deepseek_v2(flash_deepseek_v2_handle):
    await flash_deepseek_v2_handle.health(300)
    return flash_deepseek_v2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot):
    response = await flash_deepseek_v2.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot):
    response = await flash_deepseek_v2.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_load(
    flash_deepseek_v2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_deepseek_v2, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
integration-tests/models/test_flash_gemma2.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import pytest


@pytest.fixture(scope="module")
def flash_gemma2_handle(launcher):
    with launcher("google/gemma-2-9b-it", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma2(flash_gemma2_handle):
    await flash_gemma2_handle.health(300)
    return flash_gemma2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma2(flash_gemma2, response_snapshot):
    response = await flash_gemma2.generate(
        "<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.generated_text == "**Hydrogen**, light and free,\n**He"
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma2_load(flash_gemma2, generate_load, response_snapshot):
    responses = await generate_load(
        flash_gemma2,
        "<start_of_turn>user:\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\n<start_of_turn>model:\n",
        max_new_tokens=10,
        n=4,
    )

    assert responses[0].generated_text == "**Hydrogen**, light and free,\n**He"
    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
integration-tests/models/test_flash_llama_fp8.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_handle(launcher):
    with launcher("meta-llama/Meta-Llama-3-8B", num_shard=2, quantize="fp8") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_fp8(flash_llama_fp8_handle):
    await flash_llama_fp8_handle.health(300)
    return flash_llama_fp8_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
    response = await flash_llama_fp8.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
    response = await flash_llama_fp8.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot):
    responses = await generate_load(
        flash_llama_fp8, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
@@ -25,6 +25,7 @@ mod env_runtime;
 struct RawConfig {
     max_position_embeddings: Option<usize>,
     n_positions: Option<usize>,
+    model_type: Option<String>,
     max_seq_len: Option<usize>,
 }

@@ -457,6 +458,14 @@ struct Args {
     /// startup that will be available to callers via the `adapter_id` field in a request.
     #[clap(long, env)]
     lora_adapters: Option<String>,
+
+    /// Disable sending of all usage statistics
+    #[clap(default_value = "false", long, env)]
+    disable_usage_stats: bool,
+
+    /// Disable sending of crash reports, but allow anonymous usage statistics
+    #[clap(default_value = "false", long, env)]
+    disable_crash_reports: bool,
 }

 #[derive(Debug)]
@@ -1201,6 +1210,14 @@ fn spawn_webserver(
         args.model_id,
     ];

+    // Pass usage stats flags to router
+    if args.disable_usage_stats {
+        router_args.push("--disable-usage-stats".to_string());
+    }
+    if args.disable_crash_reports {
+        router_args.push("--disable-crash-reports".to_string());
+    }
+
     // Grammar support
     if args.disable_grammar_support {
         router_args.push("--disable-grammar-support".to_string());
@@ -1402,6 +1419,11 @@ fn main() -> Result<(), LauncherError> {

         let content = std::fs::read_to_string(filename)?;
         let config: RawConfig = serde_json::from_str(&content)?;
+
+        if config.model_type == Some("gemma2".to_string()) {
+            tracing::info!("Forcing flash decoding because of softcap usage");
+            std::env::set_var("FLASH_DECODING", "1");
+        }
         let config: Config = config.into();

         // Quantization usually means you're even more RAM constrained.
@@ -52,6 +52,10 @@ regex = "1.10.3"
 once_cell = "1.19.0"
 image = "0.25.1"
 base64 = { workspace = true }
+sysinfo = "0.30.13"
+uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
+csv = "1.3.0"

 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
@@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck;
 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
 use crate::{
     ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
-    HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
+    HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
 };
 use crate::{
     FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
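Note: the sysinfo, uuid and csv dependencies added to the router's Cargo.toml above back the usage-stats collection introduced later in this diff. A minimal, self-contained sketch of the kind of calls they enable; the crate feature set matches the additions above, but the sample CSV text and printed fields are illustrative assumptions, not part of the change:

use csv::ReaderBuilder;
use uuid::Uuid;

fn main() {
    // sysinfo: coarse host information, mirroring SystemInfo::new() in the usage_stats module below
    let mut system = sysinfo::System::new_all();
    system.refresh_all();
    println!("cpus: {} total_memory: {}", system.cpus().len(), system.total_memory());

    // uuid (v4 feature): a random per-process identifier, as in UserAgent::new() below
    println!("uid: {}", Uuid::new_v4());

    // csv: parse `nvidia-smi --format=csv` style output that has a header row
    let data = "name,memory.total [MiB]\nNVIDIA A10G,23028 MiB\n";
    let mut rdr = ReaderBuilder::new().has_headers(true).from_reader(data.as_bytes());
    for record in rdr.records() {
        let record = record.expect("valid CSV record");
        println!("gpu: {} memory: {}", &record[0], &record[1]);
    }
}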
@@ -332,126 +332,131 @@ impl ChatTemplate {
 pub struct ToolGrammar {}

 impl ToolGrammar {
+    // find a tool by name
+    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
+        tools
+            .iter()
+            .find(|tool| tool.function.name == name)
+            .cloned()
+            .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
+    }
+
     pub fn apply(
         tools: Option<Vec<Tool>>,
-        tool_choice: Option<ToolType>,
+        tool_choice: ToolChoice,
     ) -> Result<Option<Tools>, InferError> {
-        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
-            // let tool_prompt = tool_prompt.unwrap_or_default();
-            let tools_to_use = match tool_choice {
-                ToolType::FunctionName(name) => {
-                    vec![req_tools
-                        .iter()
-                        .find(|tool| tool.function.name == *name)
-                        .unwrap_or_else(|| panic!("Tool with name {} not found", name))
-                        .clone()]
-                }
-                ToolType::Function { function } => {
-                    let tool = req_tools
-                        .iter()
-                        .find(|tool| tool.function.name == function.name)
-                        .unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
-                        .clone();
-                    vec![tool]
-                }
-                ToolType::OneOf => req_tools.to_owned(),
-            };
+        // if no tools are provided, we return None
+        let tools = match tools {
+            Some(tools) if !tools.is_empty() => tools,
+            _ => return Ok(None),
+        };
+
+        let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
+
+        // if tools are provided and no tool_choice we default to the OneOf
+        let tools_to_use = match tool_choice {
+            ToolType::FunctionName(name) => {
+                vec![Self::find_tool_by_name(&tools, &name)?]
+            }
+            ToolType::Function { function } => {
+                vec![Self::find_tool_by_name(&tools, &function.name)?]
+            }
+            ToolType::OneOf => tools,
+            ToolType::NoTool => return Ok(None),
+        };

         // adds the error notification function for LLM feedback if required
         let mut text_response_properties = Map::new();
         text_response_properties.insert(
             "error".to_string(),
             serde_json::json!({
                 "type": "string",
                 "description": "The error or issue to notify"
             }),
         );
         text_response_properties.insert(
             "_name".to_string(),
             serde_json::json!({
                 "type": "string",
                 "const": "notify_error"
             }),
         );

         let functions: HashMap<String, serde_json::Value> = tools_to_use
             .iter()
             .map(|tool| {
                 let func = tool.function.clone();

                 // Clone the existing parameters, which are expected to be a JSON object
                 let mut params = if let Value::Object(params) = &func.arguments {
                     params.clone()
                 } else {
                     Map::new()
                 };

                 // Insert the function's description at the top level, outside of properties
                 params.insert(
                     "description".to_string(),
                     Value::String(func.description.clone().unwrap_or_default()),
                 );

                 // Ensure 'properties' exists and is an object
                 let properties = params
                     .entry("properties".to_string())
                     .or_insert_with(|| json!({}))
                     .as_object_mut()
                     .unwrap();

                 // Insert the constant for the function name inside 'properties'
                 properties.insert(
                     "_name".to_string(),
                     json!({
                         "type": "string",
                         "const": func.name.clone(),
                         // "description": "The name of the function"
                     }),
                 );

                 // Check if 'required' exists, and it is an array. If not, create an empty array.
                 let required = params
                     .entry("required".to_string())
                     .or_insert_with(|| json!([]))
                     .as_array_mut()
                     .unwrap();

                 // Add 'name' to the 'required' array if it is not already present
                 if !required.iter().any(|r| r == "_name") {
                     required.push(json!("_name"));
                 }

                 (func.name, Value::Object(params))
             })
             .chain([(
                 "notify_error".to_string(),
                 serde_json::json!({
                     "properties": text_response_properties,
                     "required": ["error", "_name"],
                     "type": "object"
                 }),
             )])
             .collect();

         let tools = Tools {
             functions_map: FunctionsMap { functions },
             properties: Properties {
                 function: tools_to_use
                     .iter()
                     .map(|tool| FunctionRef {
                         ref_path: format!("#/$functions/{}", tool.function.name.clone()),
                     })
                     .chain(std::iter::once(FunctionRef {
                         ref_path: "#/$functions/notify_error".to_string(),
                     }))
                     .collect(),
             },
         };

-            return Ok(Some(tools));
-        }
-        // Err(InferError::ToolError("No tools provided".to_string()))
-        Ok(None)
+        Ok(Some(tools))
     }
 }
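The rewritten ToolGrammar::apply above changes behaviour in two ways: unknown tool names now surface as an InferError instead of a panic, and the new NoTool choice (or an empty tool list) skips the tool grammar entirely. A self-contained sketch of those selection rules, using simplified stand-in types rather than the router's actual Tool/Tools/ToolChoice structs:

#[derive(Clone, Debug)]
struct Tool {
    name: String,
}

enum ToolType {
    OneOf,
    NoTool,
    FunctionName(String),
}

// Mirrors the selection logic only; the JSON grammar construction is omitted.
fn select_tools(tools: Vec<Tool>, choice: ToolType) -> Result<Option<Vec<Tool>>, String> {
    if tools.is_empty() {
        // no tools provided: nothing to constrain
        return Ok(None);
    }
    let selected = match choice {
        ToolType::FunctionName(name) => {
            let tool = tools
                .iter()
                .find(|t| t.name == name)
                .cloned()
                .ok_or_else(|| format!("Tool with name {} not found", name))?;
            vec![tool]
        }
        ToolType::OneOf => tools,
        ToolType::NoTool => return Ok(None),
    };
    Ok(Some(selected))
}

fn main() {
    let tools = vec![Tool { name: "get_weather".to_string() }];
    // "auto" maps to OneOf: every provided tool stays eligible
    assert_eq!(select_tools(tools.clone(), ToolType::OneOf).unwrap().unwrap().len(), 1);
    // "none" maps to NoTool: the tool grammar is skipped
    assert!(select_tools(tools.clone(), ToolType::NoTool).unwrap().is_none());
    // an unknown function name is now an error rather than a panic
    assert!(select_tools(tools, ToolType::FunctionName("missing".to_string())).is_err());
}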
@@ -7,6 +7,8 @@ mod validation;
 #[cfg(feature = "kserve")]
 mod kserve;

+pub mod usage_stats;
+
 use serde::{Deserialize, Serialize};
 use tracing::warn;
 use utoipa::ToSchema;
@@ -40,13 +42,13 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }

-#[derive(Debug, Clone, Deserialize, PartialEq)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct ChatTemplate {
     name: String,
     template: String,
 }

-#[derive(Debug, Clone, Deserialize, PartialEq)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(untagged)]
 pub enum ChatTemplateVersions {
     Single(String),
@@ -55,7 +57,7 @@ pub enum ChatTemplateVersions {

 use std::path::Path;

-#[derive(Debug, Clone, Deserialize, Default)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<ChatTemplateVersions>,
     pub completion_template: Option<String>,
@@ -824,7 +826,7 @@ pub(crate) struct ChatRequest {
     /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
-    pub tool_choice: Option<ToolType>,
+    pub tool_choice: ToolChoice,

     /// Response format constraints for the generation.
     ///
@@ -846,6 +848,7 @@ pub enum ToolType {
     OneOf,
     FunctionName(String),
     Function { function: FunctionName },
+    NoTool,
 }

 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
@@ -853,27 +856,26 @@ pub struct FunctionName {
     pub name: String,
 }

-#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
 #[serde(from = "ToolTypeDeserializer")]
 pub struct ToolChoice(pub Option<ToolType>);

 #[derive(Deserialize)]
 #[serde(untagged)]
 enum ToolTypeDeserializer {
-    None(Option<String>),
-    Some(ToolType),
+    String(String),
+    ToolType(ToolType),
 }

 impl From<ToolTypeDeserializer> for ToolChoice {
     fn from(value: ToolTypeDeserializer) -> Self {
         match value {
-            ToolTypeDeserializer::None(opt) => match opt.as_deref() {
-                Some("none") => ToolChoice(None),
-                Some("auto") => ToolChoice(Some(ToolType::OneOf)),
-                Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
-                None => ToolChoice(Some(ToolType::OneOf)),
+            ToolTypeDeserializer::String(s) => match s.as_str() {
+                "none" => ToolChoice(Some(ToolType::NoTool)),
+                "auto" => ToolChoice(Some(ToolType::OneOf)),
+                _ => ToolChoice(Some(ToolType::FunctionName(s))),
             },
-            ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
+            ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
         }
     }
 }
@@ -14,6 +14,7 @@ use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
 use text_generation_router::config::Config;
+use text_generation_router::usage_stats;
 use text_generation_router::{
     server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
 };
@@ -87,6 +88,10 @@ struct Args {
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
+    #[clap(long, env, default_value_t)]
+    disable_usage_stats: bool,
+    #[clap(long, env, default_value_t)]
+    disable_crash_reports: bool,
 }

 #[derive(Debug, Subcommand)]
@@ -128,6 +133,8 @@ async fn main() -> Result<(), RouterError> {
         messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
+        disable_usage_stats,
+        disable_crash_reports,
         command,
     } = args;

@@ -324,6 +331,7 @@ async fn main() -> Result<(), RouterError> {
         tracing::warn!("Could not find tokenizer config locally and no API specified");
         HubTokenizerConfig::default()
     });
+    let tokenizer_class = tokenizer_config.tokenizer_class.clone();

    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
        let mut tokenizer = Tokenizer::from_file(filename).ok();
@@ -378,8 +386,47 @@ async fn main() -> Result<(), RouterError> {
        }
    };

+    // Only send usage stats when TGI is run in container and the function returns Some
+    let is_container = matches!(usage_stats::is_container(), Ok(true));
+
+    let user_agent = if !disable_usage_stats && is_container {
+        let reduced_args = usage_stats::Args::new(
+            config.clone(),
+            tokenizer_class,
+            max_concurrent_requests,
+            max_best_of,
+            max_stop_sequences,
+            max_top_n_tokens,
+            max_input_tokens,
+            max_total_tokens,
+            waiting_served_ratio,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            max_batch_size,
+            revision,
+            validation_workers,
+            messages_api_enabled,
+            disable_grammar_support,
+            max_client_batch_size,
+            disable_usage_stats,
+            disable_crash_reports,
+        );
+        Some(usage_stats::UserAgent::new(reduced_args))
+    } else {
+        None
+    };
+
+    if let Some(ref ua) = user_agent {
+        let start_event =
+            usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
+        tokio::spawn(async move {
+            start_event.send().await;
+        });
+    };
+
     // Run server
-    server::run(
+    let result = server::run(
         master_shard_uds_path,
         model_info,
         compat_return_full_text,
@@ -410,8 +457,41 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         print_schema_command,
     )
-    .await?;
-    Ok(())
+    .await;
+
+    match result {
+        Ok(_) => {
+            if let Some(ref ua) = user_agent {
+                let stop_event = usage_stats::UsageStatsEvent::new(
+                    ua.clone(),
+                    usage_stats::EventType::Stop,
+                    None,
+                );
+                stop_event.send().await;
+            };
+            Ok(())
+        }
+        Err(e) => {
+            if let Some(ref ua) = user_agent {
+                if !disable_crash_reports {
+                    let error_event = usage_stats::UsageStatsEvent::new(
+                        ua.clone(),
+                        usage_stats::EventType::Error,
+                        Some(e.to_string()),
+                    );
+                    error_event.send().await;
+                } else {
+                    let unknow_error_event = usage_stats::UsageStatsEvent::new(
+                        ua.clone(),
+                        usage_stats::EventType::Error,
+                        Some("unknow_error".to_string()),
+                    );
+                    unknow_error_event.send().await;
+                }
+            };
+            Err(RouterError::WebServer(e))
+        }
+    }
 }

 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
@@ -24,7 +24,7 @@ use crate::{
     CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
     VertexResponse,
 };
-use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -1192,39 +1192,33 @@ async fn chat_completions(
         .as_secs();

     let (tool_calls, output) = if tool_grammar.is_some() {
-        // gen_text should be valid json
-        let gen_text_value: Value =
-            serde_json::from_str(&generation.generated_text).map_err(|e| {
-                (
-                    StatusCode::UNPROCESSABLE_ENTITY,
-                    Json(ErrorResponse {
-                        error: e.to_string(),
-                        error_type: "Input validation error".to_string(),
-                    }),
-                )
-            })?;
+        let gen_text_value: Value = serde_json::from_str(&generation.generated_text)
+            .map_err(|e| InferError::ToolError(e.to_string()))?;
+
+        let function = gen_text_value.get("function").ok_or(InferError::ToolError(
+            "No function found in generated text".to_string(),
+        ))?;
+
+        let name = function
+            .get("_name")
+            .and_then(Value::as_str)
+            .ok_or(InferError::ToolError(
+                "No _name found in generated text".to_string(),
+            ))?
+            .to_string();
+
+        let mut arguments = function.clone();
+        if let Value::Object(ref mut props) = arguments {
+            props.remove("_name");
+        }
+
         let tool_calls = vec![ToolCall {
             id: "0".to_string(),
             r#type: "function".to_string(),
             function: FunctionDefinition {
                 description: None,
-                name: gen_text_value
-                    .get("function")
-                    .and_then(|f| f.get("_name"))
-                    .and_then(|name| name.as_str())
-                    .unwrap_or("default_function_name")
-                    .to_string(),
-                // Serialize the JSON object obtained from "function" to an escaped JSON string
-                arguments: gen_text_value
-                    .get("function")
-                    .map(|f| {
-                        let mut f_cloned = f.clone();
-                        if let Value::Object(ref mut props) = f_cloned {
-                            props.remove("_name");
-                        }
-                        f_cloned
-                    })
-                    .unwrap_or_default(),
+                name,
+                arguments,
             },
         }];
         (Some(tool_calls), None)
@@ -1498,6 +1492,7 @@ pub async fn run(
             ToolCall,
             Function,
             FunctionDefinition,
+            ToolChoice,
         )
     ),
     tags(
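The ToolChoice changes above normalize the OpenAI-style `tool_choice` field: the strings "none" and "auto" map to NoTool and OneOf, any other string is treated as a function name, and a structured object is passed through unchanged. A self-contained sketch of the string-normalization half, using simplified stand-in types rather than the router's own (the real deserializer also accepts the structured object form); it requires serde with the derive feature and serde_json:

use serde::Deserialize;

#[derive(Debug, PartialEq)]
enum ToolType {
    OneOf,
    NoTool,
    FunctionName(String),
}

// `#[serde(from = ...)]` lets a plain JSON string deserialize into the richer type.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(from = "ToolChoiceRepr")]
struct ToolChoice(Option<ToolType>);

#[derive(Deserialize)]
#[serde(transparent)]
struct ToolChoiceRepr(String);

impl From<ToolChoiceRepr> for ToolChoice {
    fn from(value: ToolChoiceRepr) -> Self {
        match value.0.as_str() {
            "none" => ToolChoice(Some(ToolType::NoTool)),
            "auto" => ToolChoice(Some(ToolType::OneOf)),
            _ => ToolChoice(Some(ToolType::FunctionName(value.0))),
        }
    }
}

fn main() {
    let none: ToolChoice = serde_json::from_str("\"none\"").unwrap();
    let auto: ToolChoice = serde_json::from_str("\"auto\"").unwrap();
    let named: ToolChoice = serde_json::from_str("\"get_weather\"").unwrap();
    assert_eq!(none, ToolChoice(Some(ToolType::NoTool)));
    assert_eq!(auto, ToolChoice(Some(ToolType::OneOf)));
    assert_eq!(named, ToolChoice(Some(ToolType::FunctionName("get_weather".to_string()))));
}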
router/src/usage_stats.rs (new file, 355 lines)
@@ -0,0 +1,355 @@
use crate::config::Config;
use csv::ReaderBuilder;
use reqwest::header::HeaderMap;
use serde::Serialize;
use std::{
    fs::File,
    io::{self, BufRead},
    path::Path,
    process::Command,
    time::Duration,
};
use uuid::Uuid;

const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi";

#[derive(Debug, Clone, Serialize)]
pub struct UserAgent {
    pub uid: String,
    pub args: Args,
    pub env: Env,
}

impl UserAgent {
    pub fn new(reduced_args: Args) -> Self {
        Self {
            uid: Uuid::new_v4().to_string(),
            args: reduced_args,
            env: Env::new(),
        }
    }
}

#[derive(Serialize, Debug)]
pub enum EventType {
    Start,
    Stop,
    Error,
}

#[derive(Debug, Serialize)]
pub struct UsageStatsEvent {
    user_agent: UserAgent,
    event_type: EventType,
    #[serde(skip_serializing_if = "Option::is_none")]
    error_reason: Option<String>,
}

impl UsageStatsEvent {
    pub fn new(user_agent: UserAgent, event_type: EventType, error_reason: Option<String>) -> Self {
        Self {
            user_agent,
            event_type,
            error_reason,
        }
    }
    pub async fn send(&self) {
        let mut headers = HeaderMap::new();
        headers.insert("Content-Type", "application/json".parse().unwrap());
        let body = serde_json::to_string(&self).unwrap();
        let client = reqwest::Client::new();
        let _ = client
            .post(TELEMETRY_URL)
            .headers(headers)
            .body(body)
            .timeout(Duration::from_secs(5))
            .send()
            .await;
    }
}

#[derive(Debug, Clone, Serialize)]
pub struct Args {
    model_config: Option<Config>,
    tokenizer_config: Option<String>,
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_tokens: usize,
    max_total_tokens: usize,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    revision: Option<String>,
    validation_workers: usize,
    messages_api_enabled: bool,
    disable_grammar_support: bool,
    max_client_batch_size: usize,
    disable_usage_stats: bool,
    disable_crash_reports: bool,
}

impl Args {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_config: Option<Config>,
        tokenizer_config: Option<String>,
        max_concurrent_requests: usize,
        max_best_of: usize,
        max_stop_sequences: usize,
        max_top_n_tokens: u32,
        max_input_tokens: usize,
        max_total_tokens: usize,
        waiting_served_ratio: f32,
        max_batch_prefill_tokens: u32,
        max_batch_total_tokens: Option<u32>,
        max_waiting_tokens: usize,
        max_batch_size: Option<usize>,
        revision: Option<String>,
        validation_workers: usize,
        messages_api_enabled: bool,
        disable_grammar_support: bool,
        max_client_batch_size: usize,
        disable_usage_stats: bool,
        disable_crash_reports: bool,
    ) -> Self {
        Self {
            model_config,
            tokenizer_config,
            max_concurrent_requests,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_tokens,
            max_total_tokens,
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            revision,
            validation_workers,
            messages_api_enabled,
            disable_grammar_support,
            max_client_batch_size,
            disable_usage_stats,
            disable_crash_reports,
        }
    }
}

/// This is more or less a copy of the code from the `text-generation-launcher` crate to avoid a dependency
#[derive(Serialize, Debug, Clone)]
pub struct Env {
    git_sha: &'static str,
    docker_label: &'static str,
    nvidia_info: Option<Vec<NvidiaSmiInfo>>,
    xpu_info: Option<Vec<XpuSmiInfo>>,
    system_env: SystemInfo,
}

#[derive(Debug, Serialize, Clone)]
struct NvidiaSmiInfo {
    name: String,
    pci_bus_id: String,
    driver_version: String,
    pstate: String,
    pcie_link_gen_max: String,
    pcie_link_gen_current: String,
    temperature_gpu: String,
    utilization_gpu: String,
    utilization_memory: String,
    memory_total: String,
    memory_free: String,
    memory_used: String,
    reset_status_reset_required: String,
    reset_status_drain_and_reset_recommended: String,
    compute_cap: String,
    ecc_errors_corrected_volatile_total: String,
    mig_mode_current: String,
    power_draw_instant: String,
    power_limit: String,
}

impl NvidiaSmiInfo {
    fn new() -> Option<Vec<NvidiaSmiInfo>> {
        let output = Command::new("nvidia-smi")
            .args([
                "--query-gpu=name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.gpucurrent,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,reset_status.reset_required,reset_status.drain_and_reset_recommended,compute_cap,ecc.errors.corrected.volatile.total,mig.mode.current,power.draw.instant,power.limit",
                "--format=csv"
            ])
            .output()
            .ok()?;

        if !output.status.success() {
            return None;
        }

        let stdout = String::from_utf8(output.stdout).ok()?;

        let mut rdr = ReaderBuilder::new()
            .has_headers(true)
            .from_reader(stdout.as_bytes());

        let mut infos = Vec::new();

        for result in rdr.records() {
            let record = result.ok()?;
            infos.push(NvidiaSmiInfo {
                name: record[0].to_string(),
                pci_bus_id: record[1].to_string(),
                driver_version: record[2].to_string(),
                pstate: record[3].to_string(),
                pcie_link_gen_max: record[4].to_string(),
                pcie_link_gen_current: record[5].to_string(),
                temperature_gpu: record[6].to_string(),
                utilization_gpu: record[7].to_string(),
                utilization_memory: record[8].to_string(),
                memory_total: record[9].to_string(),
                memory_free: record[10].to_string(),
                memory_used: record[11].to_string(),
                reset_status_reset_required: record[12].to_string(),
                reset_status_drain_and_reset_recommended: record[13].to_string(),
                compute_cap: record[14].to_string(),
                ecc_errors_corrected_volatile_total: record[15].to_string(),
                mig_mode_current: record[16].to_string(),
                power_draw_instant: record[17].to_string(),
                power_limit: record[18].to_string(),
            });
        }

        Some(infos)
    }
}

#[derive(Debug, Serialize, Clone)]
struct XpuSmiInfo {
    device_id: usize,
    gpu_utilization: f32,
    gpu_power: f32,
    gpu_core_temperature: f32,
    gpu_memory_bandwidth_utilization: f32,
}

impl XpuSmiInfo {
    /// based on this https://github.com/intel/xpumanager/blob/master/doc/smi_user_guide.md#dump-the-device-statistics-in-csv-format
    fn new() -> Option<Vec<XpuSmiInfo>> {
        let output = Command::new("xpu-smi")
            .args([
                "dump", "-d", "-1", "-m",
                "0,1,3,17", // Metrics IDs: GPU Utilization, GPU Power, GPU Core Temperature, GPU Memory Bandwidth Utilization
                "-n", "1", "-j",
            ])
            .output()
            .ok()?;

        if !output.status.success() {
            return None;
        }

        let stdout = String::from_utf8(output.stdout).ok()?;
        let mut infos = Vec::new();

        let json_data: serde_json::Value = match serde_json::from_str(&stdout) {
            Ok(data) => data,
            Err(_) => return None,
        };

        if let Some(metrics_data) = json_data.as_array() {
            for entry in metrics_data {
                let device_id = entry["deviceId"].as_u64()? as usize;
                let gpu_utilization = entry["metrics"][0].as_f64()? as f32;
                let gpu_power = entry["metrics"][1].as_f64()? as f32;
                let gpu_core_temperature = entry["metrics"][2].as_f64()? as f32;
                let gpu_memory_bandwidth_utilization = entry["metrics"][3].as_f64()? as f32;

                infos.push(XpuSmiInfo {
                    device_id,
                    gpu_utilization,
                    gpu_power,
                    gpu_core_temperature,
                    gpu_memory_bandwidth_utilization,
                });
            }
        }

        Some(infos)
    }
}

#[derive(Serialize, Debug, Clone)]
pub struct SystemInfo {
    cpu_count: usize,
    cpu_type: String,
    total_memory: u64,
    architecture: String,
    platform: String,
}

impl SystemInfo {
    fn new() -> Self {
        let mut system = sysinfo::System::new_all();
        system.refresh_all();

        let cpu_count = system.cpus().len();
        let cpu_type = system.cpus()[0].brand().to_string();
        let total_memory = system.total_memory();
        let architecture = std::env::consts::ARCH.to_string();
        let platform = format!(
            "{}-{}-{}",
            std::env::consts::OS,
            std::env::consts::FAMILY,
            std::env::consts::ARCH
        );
        Self {
            cpu_count,
            cpu_type,
            total_memory,
            architecture,
            platform,
        }
    }
}

impl Default for Env {
    fn default() -> Self {
        Self::new()
    }
}

impl Env {
    pub fn new() -> Self {
        Self {
            system_env: SystemInfo::new(),
            nvidia_info: NvidiaSmiInfo::new(),
            xpu_info: XpuSmiInfo::new(),
            git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
            docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
        }
    }
}

pub fn is_container() -> io::Result<bool> {
    let path = Path::new("/proc/self/cgroup");
    let file = File::open(path)?;
    let reader = io::BufReader::new(file);

    for line in reader.lines() {
        let line = line?;
        // Check for common container runtimes
        if line.contains("/docker/")
            || line.contains("/docker-")
            || line.contains("/kubepods/")
            || line.contains("/kubepods-")
            || line.contains("containerd")
            || line.contains("crio")
            || line.contains("podman")
        {
            return Ok(true);
        }
    }
    Ok(false)
}
@@ -5,6 +5,7 @@ include Makefile-awq
 include Makefile-eetq
 include Makefile-selective-scan
 include Makefile-lorax-punica
+include Makefile-fbgemm

 unit-tests:
 	pytest -s -vv -m "not private" tests
@@ -27,8 +28,9 @@ install-server: gen-server
 install: install-cuda
 	echo "Installed server"

-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
+install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
 	pip install -e ".[bnb]"
+	pip install nvidia-nccl-cu12==2.22.3

 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

@@ -36,5 +38,6 @@ run-dev:
 	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

 export-requirements:
-	poetry export -o requirements_cuda.txt --without-hashes --with cuda
+	poetry export -o requirements_cuda.txt --without-hashes
 	poetry export -o requirements_rocm.txt --without-hashes
+	poetry export -o requirements_intel.txt --without-hashes
server/Makefile-fbgemm (new file, 15 lines)
@@ -0,0 +1,15 @@
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca

build-fbgemm:
	chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \
	git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
	cp fbgemm_remove_unused.patch fbgemm && \
	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \
	git submodule update --init --recursive && \
	cd fbgemm_gpu && \
	pip install -r requirements.txt && \
	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build

install-fbgemm: build-fbgemm
	cd fbgemm/fbgemm_gpu && \
	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := v2.5.9.post1
+flash_att_v2_commit_cuda := v2.6.1
 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6

 build-flash-attention-v2-cuda:
@@ -1,14 +1,14 @@
-commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
+commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
 commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
 		git clone https://github.com/Narsil/vllm.git vllm; \
 	fi
-	cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build
+	cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build

 install-vllm-cuda: build-vllm-cuda
-	cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e .
+	cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .

 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \
306
server/fbgemm_remove_unused.patch
Normal file
306
server/fbgemm_remove_unused.patch
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
|
||||||
|
index 2244ea6f..96265a48 100644
|
||||||
|
--- a/fbgemm_gpu/CMakeLists.txt
|
||||||
|
+++ b/fbgemm_gpu/CMakeLists.txt
|
||||||
|
@@ -94,14 +94,14 @@ endif()
|
||||||
|
# Build Experimental Modules
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
- # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||||
|
- add_subdirectory(experimental/example)
|
||||||
|
-endif()
|
||||||
|
-
|
||||||
|
-if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
- add_subdirectory(experimental/gemm)
|
||||||
|
-endif()
|
||||||
|
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
+# # TODO: Figure out NCCL/RCCL integration with ROCm
|
||||||
|
+# add_subdirectory(experimental/example)
|
||||||
|
+# endif()
|
||||||
|
+
|
||||||
|
+# if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
+# add_subdirectory(experimental/gemm)
|
||||||
|
+# endif()
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
|
||||||
|
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
|
||||||
|
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
index c56773fe..0c0d349e 100644
|
||||||
|
--- a/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
+++ b/fbgemm_gpu/FbgemmGpu.cmake
|
||||||
|
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
set(fbgemm_gpu_sources_static_cpu
|
||||||
|
- codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||||
|
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||||
|
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||||
|
- codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||||
|
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||||
|
- src/input_combine_ops/input_combine_cpu.cpp
|
||||||
|
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||||
|
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
|
||||||
|
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
|
||||||
|
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
|
||||||
|
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
|
||||||
|
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
|
||||||
|
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
|
||||||
|
+ # src/input_combine_ops/input_combine_cpu.cpp
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_cpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_meta.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_cpu.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_meta.cpp
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||||
|
- src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||||
|
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||||
|
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||||
|
- src/split_embeddings_cache/lxu_cache.cpp
|
||||||
|
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||||
|
+ # src/sparse_ops/sparse_ops_cpu.cpp
|
||||||
|
+ # src/sparse_ops/sparse_ops_meta.cpp
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
|
||||||
|
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
|
||||||
|
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
|
||||||
|
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
|
||||||
|
+ # src/split_embeddings_cache/lxu_cache.cpp
|
||||||
|
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
|
||||||
|
+)
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
list(APPEND fbgemm_gpu_sources_static_cpu
|
||||||
|
- codegen/inference/embedding_forward_quantized_host.cpp
|
||||||
|
- codegen/utils/embedding_bounds_check_host.cpp
|
||||||
|
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||||
|
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||||
|
- src/memory_utils/memory_utils.cpp
|
||||||
|
- src/memory_utils/memory_utils_ops.cpp
|
||||||
|
- src/memory_utils/memory_utils_ops_cpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_host.cpp
|
||||||
|
+ # codegen/utils/embedding_bounds_check_host.cpp
|
||||||
|
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
|
||||||
|
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
|
||||||
|
+ # src/memory_utils/memory_utils.cpp
|
||||||
|
+ # src/memory_utils/memory_utils_ops.cpp
|
||||||
|
+ # src/memory_utils/memory_utils_ops_cpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
|
||||||
|
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
|
||||||
|
src/quantize_ops/quantize_ops_gpu.cpp
|
||||||
|
- src/sparse_ops/sparse_ops_gpu.cpp
|
||||||
|
- src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||||
|
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||||
|
- src/metric_ops/metric_ops_host.cpp
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||||
|
- src/input_combine_ops/input_combine_gpu.cpp
|
||||||
|
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||||
|
+ # src/sparse_ops/sparse_ops_gpu.cpp
|
||||||
|
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
|
||||||
|
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
|
||||||
|
+ # src/metric_ops/metric_ops_host.cpp
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
|
||||||
|
+ # src/input_combine_ops/input_combine_gpu.cpp
|
||||||
|
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
|
||||||
|
+ )
|
||||||
|
|
||||||
|
if(NVML_LIB_PATH OR USE_ROCM)
|
||||||
|
message(STATUS "Adding merge_pooled_embeddings sources")
|
||||||
|
@@ -516,36 +518,36 @@ endif()
|
||||||
|
|
||||||
|
if(NOT FBGEMM_CPU_ONLY)
|
||||||
|
set(fbgemm_gpu_sources_static_gpu
|
||||||
|
- codegen/utils/embedding_bounds_check.cu
|
||||||
|
- codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||||
|
- src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||||
|
- src/histogram_binning_calibration_ops.cu
|
||||||
|
- src/input_combine_ops/input_combine.cu
|
||||||
|
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
|
||||||
|
- src/memory_utils/memory_utils.cu
|
||||||
|
- src/memory_utils/memory_utils_ops.cu
|
||||||
|
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
|
||||||
|
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
|
||||||
|
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_softmax_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_softmax_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_tensor_ops.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
|
||||||
|
- src/jagged_tensor_ops/jagged_unique_indices.cu
|
||||||
|
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
|
||||||
|
- src/layout_transform_ops/layout_transform_ops.cu
|
||||||
|
- src/metric_ops/metric_ops.cu
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
|
||||||
|
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
|
||||||
|
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
|
||||||
|
+ # codegen/utils/embedding_bounds_check.cu
|
||||||
|
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
|
||||||
|
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
|
||||||
|
+ # src/histogram_binning_calibration_ops.cu
|
||||||
|
+ # src/input_combine_ops/input_combine.cu
|
||||||
|
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
+ # src/memory_utils/memory_utils.cu
+ # src/memory_utils/memory_utils_ops.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+ # src/layout_transform_ops/layout_transform_ops.cu
+ # src/metric_ops/metric_ops.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
src/quantize_ops/quantize_bfloat16.cu
src/quantize_ops/quantize_fp8_rowwise.cu
src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
src/quantize_ops/quantize_msfp.cu
src/quantize_ops/quantize_padded_fp8_rowwise.cu
src/quantize_ops/quantize_mx.cu
- src/sparse_ops/sparse_async_cumsum.cu
- src/sparse_ops/sparse_block_bucketize_features.cu
- src/sparse_ops/sparse_bucketize_features.cu
- src/sparse_ops/sparse_batched_unary_embeddings.cu
- src/sparse_ops/sparse_compute_frequency_sequence.cu
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
- src/sparse_ops/sparse_group_index.cu
- src/sparse_ops/sparse_index_add.cu
- src/sparse_ops/sparse_index_select.cu
- src/sparse_ops/sparse_invert_permute.cu
- src/sparse_ops/sparse_pack_segments_backward.cu
- src/sparse_ops/sparse_pack_segments_forward.cu
- src/sparse_ops/sparse_permute_1d.cu
- src/sparse_ops/sparse_permute_2d.cu
- src/sparse_ops/sparse_permute102.cu
- src/sparse_ops/sparse_permute_embeddings.cu
- src/sparse_ops/sparse_range.cu
- src/sparse_ops/sparse_reorder_batched_ad.cu
- src/sparse_ops/sparse_segment_sum_csr.cu
- src/sparse_ops/sparse_zipf.cu
- src/split_embeddings_cache/lfu_cache_find.cu
- src/split_embeddings_cache/lfu_cache_populate.cu
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
- src/split_embeddings_cache/lru_cache_find.cu
- src/split_embeddings_cache/lru_cache_populate.cu
- src/split_embeddings_cache/lru_cache_populate_byte.cu
- src/split_embeddings_cache/lxu_cache.cu
- src/split_embeddings_cache/linearize_cache_indices.cu
- src/split_embeddings_cache/reset_weight_momentum.cu
- src/split_embeddings_utils/generate_vbe_metadata.cu
- src/split_embeddings_utils/get_infos_metadata.cu
- src/split_embeddings_utils/radix_sort_pairs.cu
- src/split_embeddings_utils/transpose_embedding_input.cu)
+ # src/sparse_ops/sparse_async_cumsum.cu
+ # src/sparse_ops/sparse_block_bucketize_features.cu
+ # src/sparse_ops/sparse_bucketize_features.cu
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
+ # src/sparse_ops/sparse_group_index.cu
+ # src/sparse_ops/sparse_index_add.cu
+ # src/sparse_ops/sparse_index_select.cu
+ # src/sparse_ops/sparse_invert_permute.cu
+ # src/sparse_ops/sparse_pack_segments_backward.cu
+ # src/sparse_ops/sparse_pack_segments_forward.cu
+ # src/sparse_ops/sparse_permute_1d.cu
+ # src/sparse_ops/sparse_permute_2d.cu
+ # src/sparse_ops/sparse_permute102.cu
+ # src/sparse_ops/sparse_permute_embeddings.cu
+ # src/sparse_ops/sparse_range.cu
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
+ # src/sparse_ops/sparse_segment_sum_csr.cu
+ # src/sparse_ops/sparse_zipf.cu
+ # src/split_embeddings_cache/lfu_cache_find.cu
+ # src/split_embeddings_cache/lfu_cache_populate.cu
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
+ # src/split_embeddings_cache/lru_cache_find.cu
+ # src/split_embeddings_cache/lru_cache_populate.cu
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
+ # src/split_embeddings_cache/lxu_cache.cu
+ # src/split_embeddings_cache/linearize_cache_indices.cu
+ # src/split_embeddings_cache/reset_weight_momentum.cu
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
+ # src/split_embeddings_utils/get_infos_metadata.cu
+ # src/split_embeddings_utils/radix_sort_pairs.cu
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
+ )

set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
                            PROPERTIES COMPILE_OPTIONS
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
index 01f1d6ab..a6b8d7a8 100644
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
${THIRDPARTY}/json/include
${NCCL_INCLUDE_DIRS})

-set(attention_ops_sources
- src/attention/attention.cpp
- src/attention/gqa_attn_splitk.cu)
+# set(attention_ops_sources
+# src/attention/attention.cpp
+# src/attention/gqa_attn_splitk.cu)

set(quantize_ops_sources
src/quantize/cutlass_extensions.cu
src/quantize/quantize.cu
src/quantize/quantize.cpp)

-set(comm_ops_sources
- src/comm/car.cu
- src/comm/car.cpp)
+# set(comm_ops_sources
+# src/comm/car.cu
+# src/comm/car.cpp)

set(experimental_gen_ai_cpp_source_files
- ${attention_ops_sources}
+ # ${attention_ops_sources}
${quantize_ops_sources}
- ${comm_ops_sources})
+ # ${comm_ops_sources}
+)

set_source_files_properties(${experimental_gen_ai_cpp_source_files}
                            PROPERTIES INCLUDE_DIRECTORIES
server/fix_torch90a.sh (new executable file, 11 lines added)
@@ -0,0 +1,11 @@
#!/bin/bash

# This script is required to patch torch < 2.4
# It adds the 90a cuda target (H100)
# This target is required to build FBGEMM kernels

torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|')

sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch
server/poetry.lock (generated, 1135 lines changed)
File diff suppressed because it is too large.
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
huggingface-hub = "^0.23"
-transformers = "^4.41"
+transformers = "^4.42"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
@@ -1,48 +1,50 @@
-backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -1,48 +1,50 @@
-backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -1,48 +1,50 @@
-backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
-opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.42.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -8,6 +8,7 @@ from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download

+from text_generation_server.utils.log import log_master

app = typer.Typer()

@@ -87,15 +88,17 @@ def serve(
    )

    if len(lora_adapter_ids) > 0:
-        logger.warning(
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        log_master(
+            logger.warning,
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
        )

    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
    # and warn the user
    if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
-        logger.warning(
-            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
+        log_master(
+            logger.warning,
+            f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
        )
        global CUDA_GRAPHS
        CUDA_GRAPHS = None
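The serve path now routes its warnings through log_master instead of calling logger.warning directly. The helper itself lives in text_generation_server/utils/log.py and is not shown in this diff; the sketch below is only a plausible reading of the idea (emit once instead of once per shard) and is an assumption, not the actual implementation.

import os

def log_master(log_fn, message: str):
    # Hypothetical sketch: only the first rank emits the message, so sharded
    # deployments do not repeat the same warning once per GPU process.
    if int(os.getenv("RANK", "0")) == 0:
        log_fn(message)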
@@ -2,6 +2,7 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
+from typing import Optional

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
@@ -43,6 +44,7 @@ def paged_attention(
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
+    softcap: Optional[float] = None,
):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
@@ -82,6 +84,8 @@ def paged_attention(
        # by the current path
        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
        # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
+        if softcap is None:
+            softcap = 0.0
        out2 = flash_attn_2_cuda.varlen_fwd(
            query,
            key_cache,
@@ -89,6 +93,7 @@ def paged_attention(
            None,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
+            None,  # pad_k
            None,
            block_tables,
            None,
@@ -100,11 +105,14 @@ def paged_attention(
            True,  # causal
            -1,  # Window_left
            -1,  # Window right
+            softcap,
            False,  # return softmax
            None,  # generator
        )
        return out2[0]
    else:
+        if softcap is not None:
+            raise RuntimeError("Paged attention doesn't support softcapping")
        input_lengths = seqlen.input_lengths
        from vllm._C import ops

@@ -205,6 +213,7 @@ if V2:
        softmax_scale,
        window_size_left=-1,
        causal=True,
+        softcap=0.0,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
@@ -218,6 +227,7 @@ if V2:
            None,
            None,
            None,
+            None,
            max_s,
            max_s,
            0.0,
@@ -226,6 +236,7 @@ if V2:
            causal,
            window_size_left,
            0,
+            softcap,
            False,
            None,
        )
@@ -241,11 +252,14 @@ else:
        max_s,
        softmax_scale,
        window_size_left=-1,
+        softcap=None,
    ):
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )
+        if softcap is not None:
+            raise NotImplementedError("softcap is only available with flash attn v2")

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
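The new softcap argument threads logit soft-capping down to the flash-attention v2 kernels (0.0 disables it, and the vLLM paged path rejects it). The capping itself happens inside flash-attention; the snippet below only illustrates the intended transform on raw attention scores and is not code from this repository.

import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash raw scores with a scaled tanh so they stay in (-softcap, +softcap).
    if softcap == 0.0:
        return scores  # 0.0 is used as the "disabled" sentinel above
    return softcap * torch.tanh(scores / softcap)

print(softcap_scores(torch.tensor([-120.0, -3.0, 0.0, 3.0, 120.0]), 30.0))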
@@ -3,6 +3,7 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers.attention import Seqlen
+from text_generation_server.utils.log import log_master
from loguru import logger

major, minor = torch.cuda.get_device_capability()
@@ -136,7 +137,10 @@ if ENGINE != "triton":
    try:
        import flash_attn_2_cuda

-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+        log_master(
+            logger.info,
+            "ROCm: using Flash Attention 2 Composable Kernel implementation.",
+        )
    except ImportError as e:
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"
@@ -4,19 +4,11 @@ from functools import lru_cache
import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params, Params4bit
-from loguru import logger
-from text_generation_server.utils.weights import Weight
-
-
-@lru_cache(1)
-def warn_deprecate_bnb():
-    logger.warning(
-        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
-    )
+from text_generation_server.utils.weights import UnquantizedWeight


@dataclass
-class BNBWeight(Weight):
+class BNBWeight(UnquantizedWeight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
@@ -82,7 +74,7 @@ class Linear8bitLt(torch.nn.Module):


@dataclass
-class BNBFP4Weight(Weight):
+class BNBFP4Weight(UnquantizedWeight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
@@ -90,7 +82,7 @@ class BNBFP4Weight(Weight):


@dataclass
-class BNBNF4Weight(Weight):
+class BNBNF4Weight(UnquantizedWeight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
@@ -2,11 +2,11 @@ from dataclasses import dataclass

import torch
from EETQ import quant_weights, w8_a16_gemm
-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import UnquantizedWeight


@dataclass
-class EETQWeight(Weight):
+class EETQWeight(UnquantizedWeight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
@@ -34,6 +34,30 @@ class Exl2Weight(Weight):
class Exl2WeightsLoader(WeightsLoader):
    """Loader for exl2-quantized weights."""

+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor paralllism.
+        """
+        try:
+            q_weight = weights.get_tensor(f"{prefix}.q_weight")
+        except RuntimeError:
+            raise RuntimeError(
+                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+            )
+
+        q_scale = weights.get_tensor(f"{prefix}.q_scale")
+        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
+        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
+        q_groups = weights.get_tensor(f"{prefix}.q_groups")
+
+        return Exl2Weight(
+            q_weight=q_weight,
+            q_scale=q_scale,
+            q_invperm=q_invperm,
+            q_scale_max=q_scale_max,
+            q_groups=q_groups,
+        )
+
    def get_weights_col_packed(
        self,
        weights: Weights,
@@ -43,46 +67,12 @@ class Exl2WeightsLoader(WeightsLoader):
        raise RuntimeError("Column-packed weights are not supported for exl")

    def get_weights_col(self, weights: Weights, prefix: str):
-        try:
-            q_weight = weights.get_tensor(f"{prefix}.q_weight")
-        except RuntimeError:
-            raise RuntimeError(
-                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
-            )
-
-        q_scale = weights.get_tensor(f"{prefix}.q_scale")
-        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
-        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
-        q_groups = weights.get_tensor(f"{prefix}.q_groups")
-
-        return Exl2Weight(
-            q_weight=q_weight,
-            q_scale=q_scale,
-            q_invperm=q_invperm,
-            q_scale_max=q_scale_max,
-            q_groups=q_groups,
-        )
+        # Sharding is not yet supported, so we return the weights as-is.
+        return self.get_weights(weights, prefix)

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        raise ValueError("get_multi_weights_col is not supported for exl2")

    def get_weights_row(self, weights: Weights, prefix: str):
-        try:
-            q_weight = weights.get_tensor(f"{prefix}.q_weight")
-        except RuntimeError:
-            raise RuntimeError(
-                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
-            )
-
-        q_scale = weights.get_tensor(f"{prefix}.q_scale")
-        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
-        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
-        q_groups = weights.get_tensor(f"{prefix}.q_groups")
-
-        return Exl2Weight(
-            q_weight=q_weight,
-            q_scale=q_scale,
-            q_invperm=q_invperm,
-            q_scale_max=q_scale_max,
-            q_groups=q_groups,
-        )
+        # Sharding is not yet supported, so we return the weights as-is.
+        return self.get_weights(weights, prefix)
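The refactor above moves the tensor loading into a single get_weights method and has the per-column and per-row entry points delegate to it, since exl2 checkpoints are not sharded. The sketch below restates that pattern in isolation; the class name is illustrative and not part of the codebase.

class UnshardedLoader:
    # Illustrative pattern only: one loading routine, thin delegating wrappers.
    def get_weights(self, weights, prefix: str):
        raise NotImplementedError

    def get_weights_col(self, weights, prefix: str):
        # Sharding is not supported, so the full tensor is returned.
        return self.get_weights(weights, prefix)

    def get_weights_row(self, weights, prefix: str):
        return self.get_weights(weights, prefix)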
@@ -1,8 +1,29 @@
-from dataclasses import dataclass
-
import torch
+
+from dataclasses import dataclass
+from typing import Optional, Union, List
+from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.weights import (
+    Weight,
+    WeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)
+from text_generation_server.utils.log import log_master, log_once
+
+FBGEMM_MM_AVAILABLE = False
+FBGEMM_DYN_AVAILABLE = False
+try:
+    import fbgemm_gpu.experimental.gen_ai
+
+    if SYSTEM == "cuda":
+        major, _ = torch.cuda.get_device_capability()
+        FBGEMM_MM_AVAILABLE = major == 9
+        FBGEMM_DYN_AVAILABLE = major >= 8
+except (ImportError, ModuleNotFoundError):
+    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")


def get_fp8_linear() -> torch.nn.Module:
@@ -11,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
    """

    if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear
@@ -21,12 +42,19 @@ def get_fp8_linear() -> torch.nn.Module:
    return Fp8Linear


-def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
-    device = weight.device
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
+        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
+            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
+        )
+        return qweight, scale
+
    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
-    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+    scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
@@ -38,28 +66,178 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    return qweight, scale


+class HybridFP8UnquantLoader(WeightsLoader):
+    """Weight loader that loads FP8 and unquantized Torch tensors."""
+
+    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
+        self.activation_scale_ub = activation_scale_ub
+        self.to_fp8 = to_fp8
+
+    def get_weights(self, weights: "Weights", prefix: str):
+        w = weights.get_tensor(f"{prefix}.weight")
+
+        if w.dtype == torch.float8_e4m3fn:
+            # FP8 branch
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
+
+        return UnquantizedWeight(w)
+
+    def get_weights_col_packed(
+        self,
+        weights: Weights,
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+
+        if w.dtype == torch.float8_e4m3fn:
+            # FP8 branch
+            scale = weights.get_packed_sharded(
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
+            ).reshape(-1)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
+
+        return UnquantizedWeight(w)
+
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)
+
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = [
+                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
+                for p in prefixes
+            ]
+            scale = torch.cat(scale, dim=0).reshape(-1)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
+
+        return UnquantizedWeight(w)
+
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
+
+        return UnquantizedWeight(w)
+
+
@dataclass
class Fp8Weight(Weight):
    weight: torch.Tensor
+    dtype: torch.dtype
+    weight_scale: Optional[torch.Tensor] = None
+    activation_scale_ub: Optional[float] = None

    def get_linear(self, bias: torch.Tensor):
-        return get_fp8_linear()(self.weight, bias)
+        if self.weight_scale is None:
+            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        return get_fp8_linear().from_fp8(
+            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
+        )


class Fp8Linear(torch.nn.Module):
    def __init__(
        self,
-        weight,
+        qweight,
+        scale,
+        scale_upper_bound,
        bias,
+        dtype,
    ) -> None:
        super().__init__()
-        self.dtype = weight.dtype
-        self.qweight, self.scale = fp8_quantize(weight)
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
+        self.dtype = dtype
+        self.qweight = qweight
+        self.scale = scale
+        self.scale_upper_bound = (
+            torch.tensor(
+                [scale_upper_bound], dtype=torch.float32, device=qweight.device
+            )
+            if scale_upper_bound is not None
+            else None
+        )
+
        self.bias = bias if bias is not None else None

+    @classmethod
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
+        return cls(
+            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+        )
+
+    @classmethod
+    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        return cls(
+            qweight=weight,
+            scale=scale,
+            scale_upper_bound=input_scale,
+            bias=bias,
+            dtype=dtype,
+        )
+
    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        qinput, scale = fp8_quantize(input)
+        if FBGEMM_MM_AVAILABLE:
+            qinput, scale = fp8_quantize(
+                input, scale_upper_bound=self.scale_upper_bound
+            )
+
+            y = torch.ops.fbgemm.f8f8bf16_rowwise(
+                qinput,
+                self.qweight,
+                scale,
+                self.scale,
+                use_fast_accum=True,
+                bias=self.bias,
+            )
+            return y.to(self.dtype)
+
+        qinput, scale = fp8_quantize(input, scalar=True)
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
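When the FBGEMM kernels are unavailable, fp8_quantize falls back to per-tensor quantization: scale by the float8 maximum over the absolute maximum, clamp, and cast. A self-contained sketch of that fallback path is shown below; it mirrors the hunk above up to details the hunk does not show (such as how the returned scale is oriented), so treat it as illustrative.

import torch

def fp8_quantize_scalar(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Scale so the largest magnitude maps onto the float8 range.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the default cast is unsaturated.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Assumed convention: return the reciprocal so a matmul can de-scale its output.
    return qweight, scale.float().reciprocal()

qw, s = fp8_quantize_scalar(torch.randn(4, 8, dtype=torch.float16))
print(qw.dtype, s.item())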
@@ -134,6 +134,115 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quantize = quantize
         self.sym = sym
 
+    def get_weights(self, weights: Weights, prefix: str):
+        from text_generation_server.layers.marlin import (
+            can_use_gptq_marlin,
+            repack_gptq_for_marlin,
+        )
+
+        self._get_gptq_params(weights)
+        if can_use_gptq_marlin(
+            bits=self.bits,
+            groupsize=self.groupsize,
+            quant_method=self.quant_method,
+            quantize=self.quantize,
+            sym=self.sym,
+        ):
+            log_once(logger.info, "Using GPTQ-Marlin kernels")
+            try:
+                qweight = weights.get_tensor(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                )
+
+            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+            scales = weights.get_tensor(f"{prefix}.scales")
+
+            return repack_gptq_for_marlin(
+                qweight=qweight,
+                scales=scales,
+                g_idx=g_idx,
+                bits=self.bits,
+                desc_act=self.desc_act,
+                groupsize=self.groupsize,
+                sym=self.sym,
+                sharded_infeatures=False,
+            )
+
+        use_exllama = True
+        if self.bits != 4:
+            use_exllama = False
+
+        if self.desc_act:
+            log_once(logger.warning, "Disabling exllama because desc_act=True")
+            use_exllama = False
+
+        try:
+            qweight = weights.get_tensor(f"{prefix}.qweight")
+        except RuntimeError:
+            raise RuntimeError(
+                "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+            )
+
+        if self.quantize == "gptq" and self.quant_method == "gptq":
+            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+        else:
+            g_idx = None
+
+        from text_generation_server.layers.gptq import (
+            HAS_EXLLAMA,
+            CAN_EXLLAMA,
+            GPTQWeight,
+        )
+
+        if use_exllama:
+            if not HAS_EXLLAMA:
+                if CAN_EXLLAMA:
+                    log_once(
+                        logger.warning,
+                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
+                    )
+                use_exllama = False
+            else:
+                log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
+
+        qzeros = weights.get_tensor(f"{prefix}.qzeros")
+        scales = weights.get_tensor(f"{prefix}.scales")
+
+        if use_exllama and g_idx is not None:
+            g_idx = g_idx - g_idx[0]
+
+        if self.quantize == "gptq" and self.quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.layers.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(
+                        qweight.shape[0] * (32 // self.bits),
+                        device=qweight.device,
+                    )
+                    // self.groupsize
+                ).to(dtype=torch.int32)
+
+        return GPTQWeight(
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            g_idx=g_idx,
+            bits=self.bits,
+            groupsize=self.groupsize,
+            use_exllama=use_exllama,
+        )
+
     def get_weights_col_packed(
         self,
         weights: Weights,
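A rough sketch of the group-index fallback used in the AWQ branch above. The shapes are hypothetical and only illustrate the `i // groupsize` layout of `g_idx`, not a real checkpoint.

import torch

bits = 4
groupsize = 128
packed_rows = 1024                          # hypothetical qweight.shape[0]
in_features = packed_rows * (32 // bits)    # each int32 row packs 32 // bits values

# Every unpacked input feature belongs to quantization group i // groupsize.
g_idx = (torch.arange(in_features) // groupsize).to(dtype=torch.int32)
assert g_idx[0] == 0
assert g_idx[-1] == in_features // groupsize - 1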
@@ -9,11 +9,12 @@ from loguru import logger
 
 from text_generation_server.layers.exl2 import Exl2Weight
 from text_generation_server.layers.gptq import GPTQWeight
+from text_generation_server.utils.log import log_master
 
 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half
 except ImportError:
-    logger.error("exllamav2_kernels not installed.")
+    log_master(logger.warning, "exllamav2_kernels not installed.")
     raise
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
@@ -33,6 +33,35 @@ class MarlinWeightsLoader(WeightsLoader):
         self.bits = bits
         self.is_marlin_24 = is_marlin_24
 
+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor paralllism.
+        """
+        is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
+        if is_marlin_24:
+            try:
+                B = weights.get_tensor(f"{prefix}.B_24")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
+                )
+
+            B_meta = weights.get_tensor(f"{prefix}.B_meta")
+            s = weights.get_tensor(f"{prefix}.s")
+            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+        else:
+            try:
+                B = weights.get_tensor(f"{prefix}.B")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `marlin` weight, make sure the model is already quantized."
+                )
+
+            s = weights.get_tensor(f"{prefix}.s")
+            weight = MarlinWeight(B=B, s=s)
+
+        return weight
+
     def get_weights_col_packed(
         self,
         weights: Weights,
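As a quick reference for the branch above, a hedged sketch of which tensor suffixes each Marlin packing format is expected to provide; the names mirror the loader, but the helper itself is hypothetical.

EXPECTED_TENSORS = {
    # dense Marlin: packed int4 weights plus per-group scales
    "marlin": ("B", "s"),
    # 2:4 structured sparsity: packed weights, sparsity metadata, scales
    "marlin_24": ("B_24", "B_meta", "s"),
}

def expected_suffixes(is_marlin_24: bool) -> tuple:
    # Mirrors the branch in MarlinWeightsLoader.get_weights above.
    return EXPECTED_TENSORS["marlin_24" if is_marlin_24 else "marlin"]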
@@ -474,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
 
     def __init__(
         self,
-        weight: torch.Tensor,
+        qweight: torch.Tensor,
+        scales: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
@@ -484,9 +514,11 @@ class GPTQMarlinFP8Linear(nn.Module):
 
         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
 
-        qweight, scale = fp8_quantize(weight)
-        scale = scale.to(torch.float16)
-        qweight, scales = repack_fp8_for_marlin(qweight, scale)
+        scales = scales.unsqueeze(0)
+        if scales.shape[1] == 1:
+            out_features, in_features = qweight.shape
+            scales = scales.repeat(1, out_features)
+        qweight, scales = repack_fp8_for_marlin(qweight, scales)
 
         in_features = qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = scales.shape[1]
@@ -500,6 +532,15 @@ class GPTQMarlinFP8Linear(nn.Module):
             out_features // 64 * 16, dtype=torch.int, device=qweight.device
         )
 
+    @classmethod
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scales = fp8_quantize(weight)
+        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
+
+    @classmethod
+    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
+
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
 
@@ -553,7 +594,7 @@ def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
     return packed
 
 
-def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
+def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     """
     Repack FP8 tensor for GPTQ-Marlin.
     """
@@ -570,7 +611,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
         qweight, perm, in_features, out_features, 8
     )
 
-    scales = scale.reshape(1, 1).repeat(1, out_features)
     scales = permute_scales(scales)
 
     return repacked, scales
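A minimal sketch of the per-tensor FP8 quantization that `from_unquant` relies on. TGI's real `fp8_quantize` helper is more involved, so treat the function below as an assumption-laden illustration only.

import torch

def naive_fp8_quantize(weight: torch.Tensor):
    # Scale so the largest magnitude maps to the float8_e4m3 maximum.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max() / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # A single per-tensor scale is later broadcast across output features
    # (`scales.repeat(1, out_features)`) before Marlin repacking.
    return qweight, scale.float()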
@@ -583,7 +623,7 @@ class MarlinWeight(Weight):
 
     Attributes:
         B (torch.Tensor): int4-quantized weights packed into int32.
-        s (torch.Tensor): float16 scales.
+        s (torch.Tensor): bfloat16/float16 scales.
     """
 
     B: torch.Tensor
@@ -591,7 +631,7 @@ class MarlinWeight(Weight):
 
     def __post_init__(self):
         assert self.B.dtype == torch.int32
-        assert self.s.dtype == torch.float16
+        assert self.s.dtype in [torch.float16, torch.bfloat16]
 
     def get_linear(self, bias: torch.Tensor):
         return MarlinLinear(weight=self, bias=bias)
@@ -1,6 +1,7 @@
 import os
 import torch
 from torch import nn
+from loguru import logger
 
 from text_generation_server.utils.import_utils import SYSTEM
 
@@ -97,6 +98,8 @@ class PositionRotaryEmbedding(nn.Module):
                 )
             elif rope_scaling["type"] == "yarn":
                 scaling_factor = rope_scaling["factor"]
+                mscale = rope_scaling.get("mscale", 1.0)
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                 return YarnPositionRotaryEmbedding(
                     dim=2 * inv_freq.shape[0],
                     max_position_embeddings=rope_scaling[
@@ -109,6 +112,8 @@ class PositionRotaryEmbedding(nn.Module):
                     attn_factor=1,
                     beta_fast=32,
                     beta_slow=1,
+                    mscale=mscale,
+                    mscale_all_dim=mscale_all_dim,
                 )
             elif rope_scaling["type"] in ["su", "longrope"]:
                 short_factor = torch.tensor(
@@ -181,6 +186,8 @@ class PositionRotaryEmbedding(nn.Module):
                 scaling_factor=scaling_factor,
             )
         elif rope_scaling["type"] == "yarn":
+            mscale = rope_scaling.get("mscale", 1.0)
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
             return YarnPositionRotaryEmbedding(
                 dim=2 * inv_freq.shape[0],
                 max_position_embeddings=rope_scaling[
@@ -193,6 +200,8 @@ class PositionRotaryEmbedding(nn.Module):
                 attn_factor=1,
                 beta_fast=32,
                 beta_slow=1,
+                mscale=mscale,
+                mscale_all_dim=mscale_all_dim,
             )
         else:
             raise NotImplementedError(
@@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim):
     return ramp_func
 
 
-def get_mscale(scale=1):
+def get_mscale(scale: float = 1.0, mscale: float = 1.0):
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
 
 
 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
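A short illustration of the YaRN magnitude correction introduced here, assuming `mscale` and `mscale_all_dim` come from `rope_scaling` as in the hunks above; the numbers are made up.

import math

def get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

scaling_factor = 40.0      # example rope_scaling["factor"]
mscale = 1.0               # rope_scaling.get("mscale", 1.0)
mscale_all_dim = 1.0       # rope_scaling.get("mscale_all_dim", 0.0)
attn_factor = 1

# The class below divides two get_mscale() terms, so when mscale equals
# mscale_all_dim the ratio is 1.0 and no extra magnitude scaling is applied.
effective = (
    get_mscale(scaling_factor, mscale)
    / get_mscale(scaling_factor, mscale_all_dim)
    * attn_factor
)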
@@ -365,6 +374,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
         attn_factor,
         beta_fast,
         beta_slow,
+        mscale: float,
+        mscale_all_dim: float,
     ):
         inv_freq = _create_inv_freq(dim, base, device)
         super().__init__(inv_freq, scaling_factor)
@@ -375,8 +386,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
         self.attn_factor = attn_factor
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
+        self.mscale_all_dim = mscale_all_dim
+        self.scaling_factor = scaling_factor
         self.mscale = float(
-            get_mscale(self.scaling_factor) * self.attn_factor
+            get_mscale(self.scaling_factor, mscale)
+            / get_mscale(self.scaling_factor, mscale_all_dim)
+            * self.attn_factor
         )  # Get n-d magnitude scaling corrected for interpolation
 
     def _update_cos_sin_cache(self, dtype, device, seqlen):
@@ -387,7 +402,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
         ):
-            if seqlen > self.max_position_embeddings:
+            if seqlen > self.max_position_embeddings or True:
                 inv_freq_extrapolation = _create_inv_freq(
                     self.dim, self.base, self.inv_freq.device
                 )
@@ -400,6 +415,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
                     self.base,
                     self.max_position_embeddings,
                 )
+
                 inv_freq_mask = (
                     1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                 ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
@@ -409,9 +425,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
                 )
 
                 self.inv_freq = inv_freq
-                self.mscale = float(
-                    get_mscale(self.scaling_factor) * self.attn_factor
-                )  # Get n-d magnitude scaling corrected for interpolation
 
         self._seq_len_cached = seqlen
         t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
 )
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_master
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -47,9 +48,7 @@ torch.set_grad_enabled(False)
 
 __all__ = [
     "Model",
-    "BLOOMSharded",
     "CausalLM",
-    "GalacticaSharded",
     "Seq2SeqLM",
     "get_model",
 ]
@@ -61,6 +60,10 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
+        FlashDeepseekV2ForCausalLM,
+        DeepseekV2Config,
+    )
     from text_generation_server.models.custom_modeling.flash_llama_modeling import (
         FlashLlamaForCausalLM,
     )
@@ -121,7 +124,7 @@ try:
     )
     from text_generation_server.layers.attention import SUPPORTS_WINDOWING
 except ImportError as e:
-    logger.warning(f"Could not import Flash Attention enabled models: {e}")
+    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False
 
@@ -133,7 +136,7 @@ MAMBA_AVAILABLE = True
 try:
     from text_generation_server.models.mamba import Mamba
 except ImportError as e:
-    logger.warning(f"Could not import Mamba: {e}")
+    log_master(logger.warning, f"Could not import Mamba: {e}")
     MAMBA_AVAILABLE = False
 
 if MAMBA_AVAILABLE:
@@ -141,6 +144,11 @@ if MAMBA_AVAILABLE:
 
 
 class ModelType(enum.Enum):
+    DEEPSEEK_V2 = {
+        "type": "deepseek_v2",
+        "name": "Deepseek V2",
+        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
+    }
     IDEFICS2 = {
         "type": "idefics2",
         "name": "Idefics 2",
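A hedged sketch of how a config's `model_type` string can be resolved against entries such as the DEEPSEEK_V2 one registered above; only that entry is reproduced and the `resolve` helper is illustrative, not TGI API.

import enum

class ModelType(enum.Enum):
    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }

def resolve(model_type: str) -> ModelType:
    # Linear scan over the registry; fine for a handful of architectures.
    for entry in ModelType:
        if entry.value["type"] == model_type:
            return entry
    raise ValueError(f"Unsupported model_type {model_type!r}")

assert resolve("deepseek_v2") is ModelType.DEEPSEEK_V2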
@@ -298,10 +306,34 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
+
+    config_dict, _ = PretrainedConfig.get_config_dict(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
+    model_type = config_dict.get("model_type", None)
+
+    quantization_config = config_dict.get("quantization_config", None)
+    if quantization_config is not None and quantize is None:
+        method = quantization_config.get("quant_method", None)
+        if method in {"gptq", "awq", "exl2"}:
+            log_master(logger.info, f"Auto selecting quantization method {method}")
+            quantize = method
+        elif method == "fbgemm_fp8":
+            log_master(logger.info, "Auto selecting quantization method fp8")
+            quantize = "fp8"
+        else:
+            log_master(logger.warning, f"Unknown quantization method {method}")
+
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
             # These quantizers only work with float16 params.
             dtype = torch.float16
+        elif quantize == "fp8":
+            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
+
+            if FBGEMM_DYN_AVAILABLE:
+                # fbgemm kernels are fp8xfp8->bf16
+                dtype = torch.bfloat16
         else:
             # Keep it as default for now and let
             # every model resolve their own default dtype.
@@ -318,11 +350,6 @@ def get_model(
     else:
         set_speculate(0)
 
-    config_dict, _ = PretrainedConfig.get_config_dict(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config_dict.get("model_type", None)
-
     speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
@@ -424,7 +451,9 @@ def get_model(
 
     speculate = get_speculate()
     if speculate > 0:
-        logger.info(f"Using speculation {method} with {speculate} input ids.")
+        log_master(
+            logger.info, f"Using speculation {method} with {speculate} input ids."
+        )
 
     if model_type is None:
         # TODO: fix how we determine model type for Mamba
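A small, assumption-labelled sketch of the auto-selection rule above: the checkpoint's `quantization_config` fills in `quantize` only when the user left it unset, with `fbgemm_fp8` folded into the generic fp8 path.

from typing import Optional

def auto_quantize(quantization_config: Optional[dict], quantize: Optional[str]) -> Optional[str]:
    if quantization_config is None or quantize is not None:
        return quantize
    method = quantization_config.get("quant_method", None)
    if method in {"gptq", "awq", "exl2"}:
        return method
    if method == "fbgemm_fp8":
        return "fp8"
    # Unknown methods are only logged in the real code; the user's value (None) is kept.
    return quantize

assert auto_quantize({"quant_method": "fbgemm_fp8"}, None) == "fp8"
assert auto_quantize({"quant_method": "gptq"}, None) == "gptq"
assert auto_quantize({"quant_method": "gptq"}, "marlin") == "marlin"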
@@ -435,14 +464,6 @@ def get_model(
         raise RuntimeError(
             f"Could not determine model type for {model_id} revision {revision}"
         )
-    quantization_config = config_dict.get("quantization_config", None)
-    if quantization_config is not None and quantize is None:
-        method = quantization_config.get("quant_method", None)
-        if method in {"gptq", "awq", "exl2"}:
-            logger.info(f"Auto selecting quantization method {method}")
-            quantize = method
-        else:
-            logger.info(f"Unknown quantization method {method}")
 
     if quantize == "exl2" and sharded:
         raise RuntimeError(
@@ -459,7 +480,40 @@ def get_model(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )
 
-    if model_type == MAMBA:
+    if model_type == DEEPSEEK_V2:
+        if FLASH_ATTENTION:
+            head_size = max(
+                config_dict.get("qk_nope_dim", 128)
+                + config_dict.get("qk_rope_dim", 64),
+                config_dict.get("v_head_dim", 128),
+            )
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashDeepseekV2ForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                default_dtype=torch.bfloat16,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                config_class=DeepseekV2Config,
+                head_size=head_size,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+    elif model_type == MAMBA:
         return Mamba(
             model_id,
             revision,
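A worked example of the `head_size` arithmetic in the DEEPSEEK_V2 branch above, using the defaults it assumes when the config omits the head dimensions.

config_dict = {}  # hypothetical config without explicit head dims
head_size = max(
    config_dict.get("qk_nope_dim", 128) + config_dict.get("qk_rope_dim", 64),
    config_dict.get("v_head_dim", 128),
)
assert head_size == 192  # query/key heads (128 + 64) are wider than value heads (128)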
@@ -551,7 +605,7 @@ def get_model(
             )
         except RuntimeError as e:
             # Lots of legacy models with various weight names.
-            logger.warning(f"Couldn't load flash gpt2 variant: {e}")
+            log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
             return CausalLM.fallback(
                 model_id,
                 revision,
@@ -573,6 +627,10 @@ def get_model(
         )
     elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
+            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+                GPTNeoXConfig,
+            )
+
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGPTNeoXForCausalLM,
@@ -582,6 +640,7 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
+                config_class=GPTNeoXConfig,
             )
         elif sharded:
             return CausalLM(
@@ -492,7 +492,7 @@ class CausalLMBatch(Batch):
 
 
 @dataclass
-class CausalLMBatchKeysLast(Batch):
+class CausalLMBatchKeysLast(CausalLMBatch):
     keys_head_dim_last: bool = False
 
 
@@ -544,7 +544,12 @@ class CausalLM(Model):
         config.quantize = quantize
         config.speculator = speculator
        if tokenizer.pad_token_id is None:
-            tokenizer.pad_token_id = config.pad_token_id
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
 
         torch.distributed.barrier(group=self.process_group)
         weights_loader = get_loader(
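A minimal sketch of the pad-token fallback chain added above: prefer the config's pad token, then the config's EOS token, then the tokenizer's EOS token. `pick_pad_token_id` is an illustrative helper, not part of the codebase.

def pick_pad_token_id(tokenizer_pad, config_pad, config_eos, tokenizer_eos):
    if tokenizer_pad is not None:
        return tokenizer_pad
    if config_pad is not None:
        return config_pad
    if config_eos is not None:
        return config_eos
    return tokenizer_eos  # may still be None if the tokenizer defines no EOS

assert pick_pad_token_id(None, None, 100001, 2) == 100001
assert pick_pad_token_id(None, None, None, 2) == 2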
@ -0,0 +1,980 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
FastLinear,
|
||||||
|
SpeculativeHead,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention.common import Seqlen
|
||||||
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV2Config(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=102400,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
moe_intermediate_size=1407,
|
||||||
|
num_hidden_layers=30,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=32,
|
||||||
|
n_shared_experts=2,
|
||||||
|
n_routed_experts=160,
|
||||||
|
ep_size=1,
|
||||||
|
routed_scaling_factor=1.0,
|
||||||
|
kv_lora_rank=512,
|
||||||
|
q_lora_rank=1536,
|
||||||
|
qk_rope_head_dim=64,
|
||||||
|
v_head_dim=128,
|
||||||
|
qk_nope_head_dim=128,
|
||||||
|
topk_method="gready",
|
||||||
|
n_group=8,
|
||||||
|
topk_group=3,
|
||||||
|
num_experts_per_tok=6,
|
||||||
|
moe_layer_freq=1,
|
||||||
|
first_k_dense_replace=0,
|
||||||
|
norm_topk_prob=False,
|
||||||
|
scoring_func="softmax",
|
||||||
|
aux_loss_alpha=0.001,
|
||||||
|
seq_aux=True,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=100000,
|
||||||
|
eos_token_id=100001,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.ep_size = ep_size
|
||||||
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.topk_method = topk_method
|
||||||
|
self.n_group = n_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.moe_layer_freq = moe_layer_freq
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.norm_topk_prob = norm_topk_prob
|
||||||
|
self.scoring_func = scoring_func
|
||||||
|
self.aux_loss_alpha = aux_loss_alpha
|
||||||
|
self.seq_aux = seq_aux
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
|
||||||
|
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
|
||||||
|
if tie_word_embeddings:
|
||||||
|
raise ValueError(
|
||||||
|
"tie_word_embeddings is not supported for Deepseek V2 models."
|
||||||
|
)
|
||||||
|
|
||||||
|
if ep_size != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_experts(config, prefix: str, mat: str, weights: Weights):
|
||||||
|
if config.quantize is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Deepseek V2 does not support weight quantization yet."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mat in ["gate_proj", "up_proj", "down_proj"]
|
||||||
|
|
||||||
|
world_size = weights.process_group.size()
|
||||||
|
rank = weights.process_group.rank()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
config.moe_intermediate_size % world_size == 0
|
||||||
|
), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||||
|
|
||||||
|
block_size = config.moe_intermediate_size // world_size
|
||||||
|
start = rank * block_size
|
||||||
|
stop = (rank + 1) * block_size
|
||||||
|
|
||||||
|
tensor = torch.empty(
|
||||||
|
(config.n_routed_experts * block_size, config.hidden_size),
|
||||||
|
dtype=weights.dtype,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(config.n_routed_experts):
|
||||||
|
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||||
|
|
||||||
|
if mat == "down_proj":
|
||||||
|
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||||
|
else:
|
||||||
|
expert_slice = slice_[start:stop]
|
||||||
|
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
|
||||||
|
dtype=weights.dtype
|
||||||
|
).to(device=weights.device)
|
||||||
|
return tensor
|
||||||
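A worked example of the sharding arithmetic in `_load_experts` above: each rank keeps one contiguous block of every expert's intermediate dimension. The sizes are illustrative stand-ins, not values read from a real DeepSeek V2 config.

moe_intermediate_size = 1408   # assumed; must be divisible by world_size
world_size = 4
n_routed_experts = 160

assert moe_intermediate_size % world_size == 0
block_size = moe_intermediate_size // world_size        # 352 rows per rank per expert
rank = 1
start, stop = rank * block_size, (rank + 1) * block_size  # rows 352..704 of each expert
rows_on_this_rank = n_routed_experts * block_size        # 56320 rows stacked into one tensor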
|
|
||||||
|
|
||||||
|
class DeepseekV2Attention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights: Weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.kv_lora_rank = config.kv_lora_rank
|
||||||
|
self.q_lora_rank = config.q_lora_rank
|
||||||
|
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||||
|
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||||
|
self.value_head_size = config.v_head_dim
|
||||||
|
self.head_pad_size = max(self.head_size, self.value_head_size)
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.qk_rope_head_dim,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
mscale = get_mscale(
|
||||||
|
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
|
||||||
|
)
|
||||||
|
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.q_lora_rank is None:
|
||||||
|
self.q_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.q_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.q_a_proj = get_linear(
|
||||||
|
weight=weights.get_weights(f"{prefix}.q_a_proj"),
|
||||||
|
bias=(
|
||||||
|
weights.get_tensor(f"{prefix}.q_a_proj.bias")
|
||||||
|
if config.attention_bias
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.q_a_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.q_a_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.q_b_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.q_b_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.kv_a_proj_with_mqa = get_linear(
|
||||||
|
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
|
||||||
|
bias=(
|
||||||
|
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
|
||||||
|
if config.attention_bias
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.kv_a_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.kv_b_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.kv_b_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: torch.Tensor,
|
||||||
|
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: Seqlen,
|
||||||
|
max_s: int,
|
||||||
|
):
|
||||||
|
if self.q_lora_rank is None:
|
||||||
|
query = self.q_proj(hidden_states)
|
||||||
|
else:
|
||||||
|
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
|
||||||
|
_, query_pe = torch.split(
|
||||||
|
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||||
|
compressed_kv, key_pe = torch.split(
|
||||||
|
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
|
||||||
|
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
|
||||||
|
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
|
||||||
|
)
|
||||||
|
|
||||||
|
key_nope, value = torch.split(
|
||||||
|
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size, heads, head_dim = query_pe.shape
|
||||||
|
query_pe = (
|
||||||
|
query_pe.view(batch_size, heads, head_dim // 2, 2)
|
||||||
|
.transpose(2, 3)
|
||||||
|
.reshape(batch_size, heads, head_dim)
|
||||||
|
)
|
||||||
|
batch_size, heads, head_dim = key_pe.shape
|
||||||
|
key_pe = (
|
||||||
|
key_pe.view(batch_size, heads, head_dim // 2, 2)
|
||||||
|
.transpose(2, 3)
|
||||||
|
.reshape(batch_size, heads, head_dim)
|
||||||
|
)
|
||||||
|
self.rotary_emb(query_pe, key_pe, cos, sin)
|
||||||
|
|
||||||
|
query[..., self.qk_nope_head_dim :] = query_pe
|
||||||
|
key = torch.empty_like(query)
|
||||||
|
key[..., : self.qk_nope_head_dim] = key_nope
|
||||||
|
key[..., self.qk_nope_head_dim :] = key_pe
|
||||||
|
|
||||||
|
# We need to pad the heads because Flash Attention does not support
|
||||||
|
# qk and v with different head sizes.
|
||||||
|
query = torch.nn.functional.pad(
|
||||||
|
query, (0, self.head_pad_size - self.head_size), value=0
|
||||||
|
)
|
||||||
|
key = torch.nn.functional.pad(
|
||||||
|
key, (0, self.head_pad_size - self.head_size), value=0
|
||||||
|
)
|
||||||
|
value = torch.nn.functional.pad(
|
||||||
|
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
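A small sketch of the head-padding trick used above, assuming the DeepSeek V2 sizes (192 for query/key heads, 128 for value heads). Flash attention expects q, k and v to share one head size, so the narrower value heads are zero-padded and the padding is sliced off after attention; only the value tensor is shown.

import torch

num_tokens, num_heads = 4, 2
head_size, value_head_size = 192, 128
head_pad_size = max(head_size, value_head_size)

value = torch.randn(num_tokens, num_heads, value_head_size)
value_padded = torch.nn.functional.pad(value, (0, head_pad_size - value_head_size), value=0)
assert value_padded.shape[-1] == head_pad_size

attn_output = value_padded                        # stand-in for the attention result
attn_output = attn_output[..., :value_head_size]  # "Remove padding."
assert attn_output.shape[-1] == value_head_size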
|
||||||
|
# Output tensor
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# flash attention
|
||||||
|
attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_output,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
max_s,
|
||||||
|
self.softmax_scale,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
paged_attention(
|
||||||
|
attn_output,
|
||||||
|
query,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove padding.
|
||||||
|
attn_output = attn_output[..., : self.value_head_size]
|
||||||
|
|
||||||
|
return self.o_proj(
|
||||||
|
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV2MLP(nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, intermediate_size: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_act = config.hidden_act
|
||||||
|
if self.hidden_act != "silu":
|
||||||
|
# Bail out because MoE only supports silu.
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently only `silu` is supported as an activation for Deepseek V2."
|
||||||
|
)
|
||||||
|
self.act = ACT2FN[self.hidden_act]
|
||||||
|
|
||||||
|
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.intermediate_size = intermediate_size // weights.process_group.size()
|
||||||
|
|
||||||
|
# TODO: This is a hotfix to be removed & properly refactored.
|
||||||
|
self.quantize = config.quantize
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
|
||||||
|
if (
|
||||||
|
SYSTEM == "rocm"
|
||||||
|
and self.hidden_act == "silu"
|
||||||
|
and hidden_states.shape[0] == 1
|
||||||
|
and not self.quantize
|
||||||
|
):
|
||||||
|
out = torch.empty(
|
||||||
|
hidden_states.shape[0],
|
||||||
|
self.intermediate_size,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||||
|
return self.down_proj(out, reduce=reduce)
|
||||||
|
else:
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSparseMoE(nn.Module):
|
||||||
|
def __init__(self, prefix, config: DeepseekV2Config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_dim = config.hidden_size
|
||||||
|
self.moe_intermediate_size = (
|
||||||
|
config.moe_intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
self.n_routed_experts = config.n_routed_experts
|
||||||
|
self.n_expert_group = config.n_group
|
||||||
|
self.topk_group = config.topk_group
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
self.norm_topk_prob = config.norm_topk_prob
|
||||||
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
|
|
||||||
|
gate_proj = _load_experts(
|
||||||
|
config, f"{prefix}.experts", "gate_proj", weights
|
||||||
|
).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
|
||||||
|
|
||||||
|
up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view(
|
||||||
|
self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
|
||||||
|
|
||||||
|
self.down_proj = (
|
||||||
|
_load_experts(config, f"{prefix}.experts", "down_proj", weights)
|
||||||
|
.view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gating
|
||||||
|
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||||
|
|
||||||
|
if config.n_shared_experts is not None:
|
||||||
|
self.shared_experts = DeepseekV2MLP(
|
||||||
|
prefix=f"{prefix}.shared_experts",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
intermediate_size=config.moe_intermediate_size
|
||||||
|
* config.n_shared_experts,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.shared_experts = None
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.shared_experts is not None:
|
||||||
|
shared_output = self.shared_experts(x, reduce=False)
|
||||||
|
else:
|
||||||
|
shared_output = None
|
||||||
|
|
||||||
|
router_logits = self.gate(x)
|
||||||
|
topk_weights, topk_ids = grouped_topk(
|
||||||
|
x,
|
||||||
|
router_logits,
|
||||||
|
self.top_k,
|
||||||
|
renormalize=self.norm_topk_prob,
|
||||||
|
num_expert_group=self.n_expert_group,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
)
|
||||||
|
out = (
|
||||||
|
fused_experts(
|
||||||
|
x,
|
||||||
|
self.gate_up_proj,
|
||||||
|
self.down_proj,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
* self.routed_scaling_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
if shared_output is not None:
|
||||||
|
out = out + shared_output
|
||||||
|
|
||||||
|
# Reduce sum
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
|
||||||
|
return out.view(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class DenseMoE(nn.Module):
|
||||||
|
def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_dim = config.hidden_size
|
||||||
|
self.moe_intermediate_size = config.moe_intermediate_size
|
||||||
|
self.n_routed_experts = config.n_routed_experts
|
||||||
|
self.n_expert_group = config.n_group
|
||||||
|
self.topk_group = config.topk_group
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
self.norm_topk_prob = config.norm_topk_prob
|
||||||
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
|
|
||||||
|
# Gating
|
||||||
|
#
|
||||||
|
# Seems like no one quantizes the gate.
|
||||||
|
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||||
|
|
||||||
|
self.experts = [
|
||||||
|
DeepseekV2MLP(
|
||||||
|
f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
|
||||||
|
)
|
||||||
|
for i in range(self.n_routed_experts)
|
||||||
|
]
|
||||||
|
|
||||||
|
if config.n_shared_experts is not None:
|
||||||
|
self.shared_experts = DeepseekV2MLP(
|
||||||
|
prefix=f"{prefix}.shared_experts",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
intermediate_size=config.moe_intermediate_size
|
||||||
|
* config.n_shared_experts,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.shared_experts = None
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: (sequence_length, model_dim)
|
||||||
|
gate_logits: (sequence_length, n_experts)
|
||||||
|
"""
|
||||||
|
# optional reshape
|
||||||
|
input_shape = x.shape
|
||||||
|
x = x.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if self.shared_experts is not None:
|
||||||
|
shared_output = self.shared_experts(x, reduce=False)
|
||||||
|
else:
|
||||||
|
shared_output = None
|
||||||
|
|
||||||
|
# gate_logits: (sequence_length, n_experts)
|
||||||
|
router_logits = self.gate(x)
|
||||||
|
|
||||||
|
topk_weights, topk_ids = grouped_topk(
|
||||||
|
x,
|
||||||
|
router_logits,
|
||||||
|
self.top_k,
|
||||||
|
renormalize=self.norm_topk_prob,
|
||||||
|
num_expert_group=self.n_expert_group,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
|
||||||
|
|
||||||
|
if shared_output is not None:
|
||||||
|
out = out + shared_output
|
||||||
|
|
||||||
|
# Reduce sum
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def moe_infer_gpu(
|
||||||
|
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
|
||||||
|
):
|
||||||
|
weights = torch.zeros(
|
||||||
|
topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
weights.scatter_(1, topk_ids, topk_weight)
|
||||||
|
|
||||||
|
out = x.new_zeros(x.shape[0], self.hidden_dim)
|
||||||
|
for i, expert in enumerate(self.experts):
|
||||||
|
# Add expert output to out with masking
|
||||||
|
out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV2Layer(nn.Module):
|
||||||
|
def __init__(self, prefix, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = f"{prefix}.layers.{layer_id}"
|
||||||
|
|
||||||
|
self.self_attn = DeepseekV2Attention(
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
config.n_routed_experts is not None
|
||||||
|
and layer_id >= config.first_k_dense_replace
|
||||||
|
and layer_id % config.moe_layer_freq == 0
|
||||||
|
):
|
||||||
|
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
|
||||||
|
self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
|
||||||
|
else:
|
||||||
|
self.mlp = DeepseekV2MLP(
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: torch.Tensor,
|
||||||
|
kv_cache,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: Seqlen,
|
||||||
|
max_s: int,
|
||||||
|
):
|
||||||
|
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
# faster post attention rms norm
|
||||||
|
normed_attn_res_output, residual = self.post_attention_layernorm(
|
||||||
|
attn_output, residual
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.mlp(normed_attn_res_output)
|
||||||
|
|
||||||
|
return output, residual
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV2Model(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights: Weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DeepseekV2Layer(
|
||||||
|
prefix,
|
||||||
|
layer_id,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid to index in each layer
|
||||||
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||||
|
position_ids, max_s, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashDeepseekV2ForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights: Weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = DeepseekV2Model(
|
||||||
|
"model" if not prefix else f"{prefix}.model", config, weights
|
||||||
|
)
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
config,
|
||||||
|
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
# Functions below are from vLLM:
|
||||||
|
#
|
||||||
|
# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
|
||||||
|
#
|
||||||
|
# Remove after we have synced our version with upstream.
|
||||||
|
|
||||||
|
|
||||||
|
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
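An illustrative call of `grouped_topk` with tiny shapes, just to show the expected output shapes and the renormalization; the tensors are random and the sizes are not taken from DeepSeek V2.

import torch

tokens, experts = 2, 8
x = torch.randn(tokens, 16)                 # hidden states (unused by the scoring itself)
router_logits = torch.randn(tokens, experts)

weights, ids = grouped_topk(
    x,
    router_logits,
    topk=2,
    renormalize=True,
    num_expert_group=4,   # 8 experts split into 4 groups of 2
    topk_group=2,         # routing restricted to the best 2 groups
)
assert weights.shape == (tokens, 2) and ids.shape == (tokens, 2)
assert torch.allclose(weights.sum(dim=-1), torch.ones(tokens))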
|
|
||||||
|
|
||||||
|
def get_default_config(
|
||||||
|
M: int,
|
||||||
|
E: int,
|
||||||
|
N: int,
|
||||||
|
K: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: Optional[str],
|
||||||
|
) -> Dict[str, int]:
|
||||||
|
config = {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 32,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
}
|
||||||
|
if M <= E:
|
||||||
|
config = {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def fused_experts(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
inplace: bool = False,
|
||||||
|
override_config: Optional[Dict[str, Any]] = None,
|
||||||
|
use_fp8: bool = False,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
# Check constraints.
|
||||||
|
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||||
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||||
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
|
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||||
|
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||||
|
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
import triton.language as tl
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
|
get_moe_configs,
|
||||||
|
invoke_fused_moe_kernel,
|
||||||
|
moe_align_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
M, _ = hidden_states.shape
|
||||||
|
E, N, _ = w1.shape
|
||||||
|
|
||||||
|
if override_config:
|
||||||
|
config = override_config
|
||||||
|
else:
|
||||||
|
# First try to load optimal config from the file
|
||||||
|
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
|
||||||
|
|
||||||
|
if configs:
|
||||||
|
# If an optimal configuration map has been found, look up the
|
||||||
|
# optimal config
|
||||||
|
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||||
|
else:
|
||||||
|
# Else use the default config
|
||||||
|
config = get_default_config(
|
||||||
|
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
intermediate_cache1 = torch.empty(
|
||||||
|
(M, topk_ids.shape[1], N),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
intermediate_cache2 = torch.empty(
|
||||||
|
(M * topk_ids.shape[1], N // 2),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
intermediate_cache3 = torch.empty(
|
||||||
|
(M, topk_ids.shape[1], w2.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||||
|
topk_ids, config["BLOCK_SIZE_M"], E
|
||||||
|
)
|
||||||
|
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||||
|
|
||||||
|
invoke_fused_moe_kernel(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
intermediate_cache1,
|
||||||
|
a1_scale,
|
||||||
|
w1_scale,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_padded,
|
||||||
|
False,
|
||||||
|
topk_ids.shape[1],
|
||||||
|
config,
|
||||||
|
compute_type=compute_type,
|
||||||
|
use_fp8=use_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||||
|
|
||||||
|
invoke_fused_moe_kernel(
|
||||||
|
intermediate_cache2,
|
||||||
|
w2,
|
||||||
|
intermediate_cache3,
|
||||||
|
a2_scale,
|
||||||
|
w2_scale,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_padded,
|
||||||
|
True,
|
||||||
|
1,
|
||||||
|
config,
|
||||||
|
compute_type=compute_type,
|
||||||
|
use_fp8=use_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
if inplace:
|
||||||
|
return torch.sum(
|
||||||
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
dim=1,
|
||||||
|
out=hidden_states,
|
||||||
|
)
|
||||||
|
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
|
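A minimal sketch of how `grouped_topk` and `fused_experts` fit together in a MoE feed-forward pass. The `gate_weight`, `w13`, and `w2` names, the shapes in the comments, and the single-expert-group routing are illustrative assumptions, not code from this commit, and the call requires the vLLM kernels imported inside `fused_experts`.

# Illustrative sketch only; names and shapes below are assumptions.
def moe_forward(hidden_states, gate_weight, w13, w2, topk=2):
    # Router logits for every token over all experts: [num_tokens, num_experts]
    gating_output = hidden_states @ gate_weight.t()
    topk_weights, topk_ids = grouped_topk(
        hidden_states,
        gating_output,
        topk=topk,
        renormalize=True,
        num_expert_group=1,
        topk_group=1,
    )
    # Assumed shapes: w13 is [num_experts, 2 * intermediate_size, hidden_size],
    # w2 is [num_experts, hidden_size, intermediate_size].
    return fused_experts(hidden_states, w13, w2, topk_weights, topk_ids, inplace=True)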
@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class Gemma2Config(PretrainedConfig):
@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))

@ -188,6 +189,7 @@ class FlashGemma2Attention(torch.nn.Module):
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
         )
+        self.softcap = config.attn_logit_softcapping

         self.query_key_value = load_attention(config, prefix, weights)

@ -245,6 +247,7 @@ class FlashGemma2Attention(torch.nn.Module):
                 self.softmax_scale,
                 causal=self.causal,
                 window_size_left=self.window_size,
+                softcap=self.softcap,
             )
         # Decode
         else:
@ -258,6 +261,7 @@ class FlashGemma2Attention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                softcap=self.softcap,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -465,6 +469,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
             config=config,
             weights=weights,
         )
+        self.softcap = config.final_logit_softcapping
+        assert isinstance(self.softcap, float)

         def forward(
             self,
@ -494,4 +500,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
+
+        logits /= self.softcap
+        logits = torch.tanh(logits)
+        logits *= self.softcap
+
         return logits, speculative_logits
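The three added lines in the last hunk apply Gemma 2's final-logit soft-capping. As a standalone, hedged sketch of the same transform (the tensor values are made up):

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Same math as the in-place lines above: logits /= cap; tanh; logits *= cap
    return cap * torch.tanh(logits / cap)

capped = softcap(torch.randn(2, 4), cap=30.0)  # every value now lies in (-30, 30)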
@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class GemmaConfig(PretrainedConfig):
@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))

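The `_load_gqa` hunks above replace a string check on `config.quantize` with an `isinstance` check on the loaded weight wrapper. A small hedged sketch of that pattern; the wrapper class here is a local stand-in, not the real `UnquantizedWeight`:

from dataclasses import dataclass
import torch

@dataclass
class PlainWeight:  # stand-in for UnquantizedWeight
    weight: torch.Tensor

def cast_if_unquantized(w, dtype, device):
    # Only plain tensors are cast; quantized wrappers keep their packed storage.
    if isinstance(w, PlainWeight):
        w.weight = w.weight.to(dtype=dtype).to(device=device)
    return w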
@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@ -42,16 +41,15 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
-    DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

 if SYSTEM == "rocm":
     try:
@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):

 @contextmanager
 def no_fp8(weights: Weights):
+    """De-activate fp8 auto conversion for the duration of this context manager"""
     weights_loader = weights.weights_loader
-    if (
-        isinstance(weights_loader, DefaultWeightsLoader)
-        and weights_loader.weight_class is Fp8Weight
-    ):
-        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+    if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+        weights_loader = HybridFP8UnquantLoader(
+            weights_loader.activation_scale_ub, to_fp8=False
+        )

     with weights.use_loader(weights_loader):
         yield
@ -418,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.layers = nn.ModuleList(
+
+        # Skip fp8 quant for first and last layers
+        self.layers = nn.ModuleList()
+        with no_fp8(weights):
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=0,
+                    prefix=(
+                        "model.layers.0" if not prefix else "{prefix}.model.layers.0"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
+        self.layers.extend(
             [
                 FlashLlamaLayer(
                     index=layer_id,
@ -430,9 +443,26 @@ class FlashLlamaModel(torch.nn.Module):
                     config=config,
                     weights=weights,
                 )
-                for layer_id in range(config.num_hidden_layers)
+                # Skip first and last layers
+                for layer_id in range(1, config.num_hidden_layers - 1)
             ]
         )
+
+        with no_fp8(weights):
+            last_layer_id = config.num_hidden_layers - 1
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=last_layer_id,
+                    prefix=(
+                        f"model.layers.{last_layer_id}"
+                        if not prefix
+                        else f"{prefix}.model.layers.{last_layer_id}"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
         self.norm = FastRMSNorm.load(
             prefix="model.norm" if not prefix else f"{prefix}.model.norm",
             weights=weights,
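The `no_fp8` context manager above temporarily swaps the active weights loader so that the first and last layers are loaded unquantized. A simplified, hedged stand-in for the `Weights.use_loader` mechanism it relies on; the attribute names are assumptions for the sketch:

from contextlib import contextmanager

@contextmanager
def use_loader(weights, loader):
    # Swap the loader in, restore the previous one on exit.
    previous = weights.weights_loader
    weights.weights_loader = loader
    try:
        yield
    finally:
        weights.weights_loader = previous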
@ -52,6 +52,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class MixtralConfig(PretrainedConfig):
@ -138,16 +139,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))

@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
-from transformers.models.gpt_neox import GPTNeoXConfig
+from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple

 from text_generation_server.layers.attention import (
@ -45,6 +45,13 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class GPTNeoXConfig(TransformersGPTNeoXConfig):
+    attribute_map = {
+        "num_key_value_heads": "num_attention_heads",
+    }


 def load_row(config, prefix: str, weights, bias: bool):
@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool):

 def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
     weight = weights.get_multi_weights_col([prefix], dim=0)
-    if isinstance(weight, torch.Tensor):
+    if isinstance(weight, UnquantizedWeight):
         # Only on non quantized versions
-        weight = (
-            weight.view(
+        weight.weight = (
+            weight.weight.view(
                 num_heads,
                 3,
                 head_size,
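The new `GPTNeoXConfig` subclass uses the `attribute_map` hook from `PretrainedConfig`, so reads of `num_key_value_heads` fall through to `num_attention_heads`. A minimal hedged illustration; it assumes `transformers` is installed and the toy config below is not part of the commit:

from transformers import PretrainedConfig

class TinyConfig(PretrainedConfig):
    attribute_map = {"num_key_value_heads": "num_attention_heads"}

    def __init__(self, num_attention_heads=8, **kwargs):
        self.num_attention_heads = num_attention_heads
        super().__init__(**kwargs)

cfg = TinyConfig()
# The aliased attribute resolves to the mapped one.
assert cfg.num_key_value_heads == cfg.num_attention_heads == 8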
@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class Starcoder2Config(PretrainedConfig):
@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     if config.use_bias:
         w = [
@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module):
         weights,
     ):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop

         if self.n_heads % weights.process_group.size() != 0:
             raise ValueError(
@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module):

     def __init__(self, config, prefix, weights):
         super().__init__()
-        attn_impl = config.attn_config["attn_impl"]
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
+        attn_impl = config.attn_config.attn_impl
+        self.attn_impl = config.attn_config.attn_impl
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
         self.d_model = config.d_model
         d_model = config.d_model
         self.n_heads = config.n_heads
-        self.softmax_scale = config.attn_config["softmax_scale"]
+        self.softmax_scale = config.attn_config.softmax_scale
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        self.attn_dropout_p = config.attn_config.attn_pdrop
         # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
         self.Wqkv = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@ -614,9 +614,9 @@ class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.prefix = prefix
-        if config.attn_config["attn_type"] != "multihead_attention":
+        if config.attn_config.attn_type != "multihead_attention":
             raise NotImplementedError(
-                f"""Not implemented attn {config.attn_config["attn_type"]}"""
+                f"""Not implemented attn {config.attn_config.attn_type}"""
             )
         resid_pdrop = config.resid_pdrop
         if config.no_bias:
@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel):
         self.world_size = weights.process_group.size()
         self.rank = weights.process_group.rank()
         self.n_heads = config.n_heads
-        self.attn_impl = config.attn_config["attn_impl"]
-        self.prefix_lm = config.attn_config["prefix_lm"]
-        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
-        self.alibi = config.attn_config["alibi"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.attn_impl = config.attn_config.attn_impl
+        self.prefix_lm = config.attn_config.prefix_lm
+        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
+        self.alibi = config.attn_config.alibi
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
         if config.init_device == "mixed":
             if dist.get_local_rank() == 0:
                 config.init_device = "cpu"
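The MPT hunks assume `config.attn_config` is now an object with attributes instead of a raw dict; the config-side change itself is not shown in this diff. One hedged way to get that behaviour would be a thin namespace wrapper (the keys below are examples, not an exhaustive list):

from types import SimpleNamespace

raw = {"attn_impl": "torch", "clip_qkv": None, "qk_ln": False,
       "softmax_scale": None, "attn_pdrop": 0.0, "attn_type": "multihead_attention"}
attn_config = SimpleNamespace(**raw)
assert attn_config.attn_impl == "torch"  # attribute access instead of raw["attn_impl"]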
@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
+from text_generation_server.utils.log import log_master
 from text_generation_server.utils.tokens import batch_top_tokens
-from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
-    hub,
 )
 from text_generation_server.models.types import (
     Batch,
@ -839,7 +838,9 @@ class FlashCausalLM(Model):
         default_dtype=torch.float16,
         aliases=None,
         # Used for Santacoder override of config
-        num_kv_heads=None,
+        num_kv_heads: Optional[int] = None,
+        # Deepseek V2 uses different QK and V dims.
+        head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
@ -922,7 +923,16 @@ class FlashCausalLM(Model):
             else num_kv_heads
         )
         assert self.num_kv_heads > 0
-        self.head_size = config.hidden_size // config.num_attention_heads
+
+        if head_size is None:
+            # Some models use GQA and different sizes for o_proj
+            # and q_proj, that allows for that.
+            if hasattr(config, "head_dim"):
+                self.head_size = config.head_dim
+            else:
+                self.head_size = config.hidden_size // config.num_attention_heads
+        else:
+            self.head_size = head_size

         self.cuda_graphs = {}
         self.kv_cache = []
@ -1150,31 +1160,36 @@ class FlashCausalLM(Model):
                     f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                 )

-                logger.info(
-                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
+                log_master(
+                    logger.info,
+                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
                 )

                 if os.path.isfile(tunableop_filepath):
-                    logger.info(
-                        f"The file {tunableop_filepath} already exists and will be reused."
+                    log_master(
+                        logger.info,
+                        f"The file {tunableop_filepath} already exists and will be reused.",
                     )
                     torch.cuda.tunable.read_file(tunableop_filepath)

                 os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                 for seqlen in tuning_sequences:
-                    logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+                    log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                     self.tunableop_warmup(seqlen)
                 torch.cuda.tunable.write_file(tunableop_filepath)
                 torch.cuda.tunable.tuning_enable(False)
             else:
-                logger.info(
-                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+                log_master(
+                    logger.info,
+                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
                 )

         if CUDA_GRAPHS:
             try:
-                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
+                log_master(
+                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                )
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
@ -1182,7 +1197,9 @@ class FlashCausalLM(Model):
            except torch.cuda.OutOfMemoryError:
                logger.exception(f"Decode cuda graph warmup failed")
        else:
-           logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
+           log_master(
+               logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+           )

        return int(num_blocks * BLOCK_SIZE)

@ -1534,8 +1551,7 @@ class FlashCausalLM(Model):
             left = 0

             if n_accepted_ids > 1:
-                if RANK == 0:
-                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")

             current_stopped = False
             for j in range(index, index + n_accepted_ids):
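The new `head_size` handling resolves the attention head size in a fixed order: an explicit argument wins, then `config.head_dim` if the model defines one, otherwise the classic `hidden_size // num_attention_heads`. A hedged standalone sketch of that order:

from typing import Optional

def resolve_head_size(config, head_size: Optional[int] = None) -> int:
    # Explicit override first (e.g. Deepseek V2 with different QK/V dims),
    # then a model-provided head_dim, then the default split.
    if head_size is not None:
        return head_size
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads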
@ -1,15 +1,16 @@
 import torch
 import os
 from loguru import logger
-from typing import Dict
+from typing import Dict, Optional

+from text_generation_server.utils.log import log_master

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
 BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
-    logger.info("Using FLASH_DECODING")
+    log_master(logger.info, "Using FLASH_DECODING")


 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
@ -26,11 +27,9 @@ else:
 if cuda_graphs is not None:
     cuda_graphs.sort(reverse=True)


 CUDA_GRAPHS = cuda_graphs

 # This is overridden at model loading.
-global MODEL_ID
 MODEL_ID = None


@ -41,8 +40,7 @@ def set_model_id(model_id: str):

 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
-global ADAPTER_TO_INDEX
-ADAPTER_TO_INDEX: Dict[str, int] = None
+ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     AdapterSource,
 )
+from text_generation_server.utils.log import log_master
 from loguru import logger


@ -204,8 +205,9 @@ class Model(ABC):
                 f"order to use the dynamic adapter loading feature."
             )

-        logger.info(
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+        log_master(
+            logger.info,
+            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
         )
         weight_names = tuple([v[0] for v in self.target_to_layer.values()])
         (
@ -240,8 +242,9 @@ class Model(ABC):
             layer_weights.add_adapter(adapter_index, adapter_weights)

         if len(unused_weight_names) > 0:
-            logger.warning(
-                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+            log_master(
+                logger.warning,
+                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
             )

         if adapter_tokenizer is not None:
@ -1,4 +1,3 @@
-from itertools import repeat
 import torch
 from PIL import Image
 from io import BytesIO
@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
 )
+from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor

 tracer = trace.get_tracer(__name__)
@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         num_features = get_number_of_features(height, width, config)
         from loguru import logger

-        logger.info(
-            f"Found {num_features} features in image of resolution {height}x{width}"
+        log_master(
+            logger.info,
+            f"Found {num_features} features in image of resolution {height}x{width}",
         )
         return "<image>" * num_features

@ -261,7 +262,12 @@ class VlmCausalLM(FlashCausalLM):
             **processor_kwargs,
         )
         self.batch_class = batch_class
-        super().__init__(model_id=model_id, **kwargs)
+        super().__init__(
+            model_id=model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )

     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
@ -56,7 +56,7 @@ def initialize_torch_distributed():
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
-        options._timeout = timedelta(seconds=60)
+        options._timeout = timedelta(seconds=120)
     else:
         backend = "gloo"
         options = None
@ -76,7 +76,7 @@ def initialize_torch_distributed():
             backend="ccl",
             world_size=WORLD_SIZE,
             rank=RANK,
-            timeout=timedelta(seconds=60),
+            timeout=timedelta(seconds=120),
             pg_options=options,
         )
     else:
@ -84,7 +84,7 @@ def initialize_torch_distributed():
             backend=backend,
             world_size=WORLD_SIZE,
             rank=RANK,
-            timeout=timedelta(seconds=60),
+            timeout=timedelta(seconds=120),
             pg_options=options,
         )
     else:
@ -1,6 +1,15 @@
 from functools import lru_cache
+from text_generation_server.utils.dist import RANK


 @lru_cache(10)
-def log_once(log, msg: str):
-    log(msg)
+def log_once(log, msg: str, master=True):
+    if master:
+        log_master(log, msg)
+    else:
+        log(msg)
+
+
+def log_master(log, msg: str):
+    if RANK == 0:
+        log(msg)
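Usage sketch for the helpers defined above; it assumes the distributed environment variables are set so that `RANK` resolves, and only rank 0 emits the messages:

from loguru import logger
from text_generation_server.utils.log import log_master, log_once

log_master(logger.info, "printed on rank 0 only")
log_once(logger.warning, "printed at most once per unique message, on rank 0 by default")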
@ -11,6 +11,7 @@ from text_generation_server.utils.weights import (
 )


+# TODO: Split this config to have a single config type per quant method
 @dataclass
 class _QuantizerConfig:
     bits: int
@ -21,6 +22,11 @@ class _QuantizerConfig:
     sym: bool


+@dataclass
+class _FP8QuantizerConfig:
+    activation_scale_ub: float
+
+
 # We should probably do this with Pytantic JSON deserialization,
 # but for now we'll stay close to the old _set_gptq_params.
 def _get_quantizer_config(model_id, revision):
@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision):
             filename = hf_hub_download(model_id, filename=filename, revision=revision)
         with open(filename, "r") as f:
             data = json.load(f)
+
+        # FP8 config
+        if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
+            return _FP8QuantizerConfig(
+                activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
+            )
+
         bits = data["quantization_config"]["bits"]
         groupsize = data["quantization_config"]["group_size"]
         # Order is important here, desc_act is missing on some real models
@ -99,6 +112,12 @@ def get_loader(
     if quantize in {"awq", "gptq"}:
         from text_generation_server.layers.gptq import GPTQWeightsLoader

+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
         return GPTQWeightsLoader(
             bits=quantizer_config.bits,
             desc_act=quantizer_config.desc_act,
@ -127,18 +146,28 @@ def get_loader(
         from text_generation_server.layers.exl2 import Exl2WeightsLoader

         return Exl2WeightsLoader()
-    elif quantize == "fp8":
-        from text_generation_server.layers.fp8 import Fp8Weight
-
-        return DefaultWeightsLoader(Fp8Weight)
     elif quantize == "marlin":
         from text_generation_server.layers.marlin import MarlinWeightsLoader

+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
         return MarlinWeightsLoader(
             bits=quantizer_config.bits,
             is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
         )
-    elif quantize is None:
-        return DefaultWeightsLoader(UnquantizedWeight)
+    elif quantize == "fp8" or quantize is None:
+        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+        # Since the default for the quantize config is _QuantizerConfig,
+        # we need to add this check to not get an attribute error
+        activation_scale_ub = None
+        if isinstance(quantizer_config, _FP8QuantizerConfig):
+            activation_scale_ub = quantizer_config.activation_scale_ub
+
+        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
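A condensed, hedged sketch of the dispatch the quantization hunks introduce: an `fbgemm_fp8` checkpoint yields an FP8 config carrying `activation_scale_ub`, anything else keeps the GPTQ-style config. The class names below are local stand-ins and the dict literal is only an example:

from dataclasses import dataclass
from typing import Union

@dataclass
class GptqStyleConfig:  # stand-in for _QuantizerConfig
    bits: int
    groupsize: int

@dataclass
class Fp8Config:  # stand-in for _FP8QuantizerConfig
    activation_scale_ub: float

def parse_quant_config(data: dict) -> Union[GptqStyleConfig, Fp8Config]:
    qc = data["quantization_config"]
    if qc["quant_method"] == "fbgemm_fp8":
        return Fp8Config(activation_scale_ub=qc["activation_scale_ub"])
    return GptqStyleConfig(bits=qc["bits"], groupsize=qc.get("group_size", -1))

example = {"quantization_config": {"quant_method": "fbgemm_fp8", "activation_scale_ub": 1200.0}}
assert isinstance(parse_quant_config(example), Fp8Config)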
@ -1,12 +1,12 @@
+import torch
+
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass
-from enum import Enum, auto
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Type

-import torch
 from safetensors import safe_open
+from dataclasses import dataclass

 from text_generation_server.utils.import_utils import SYSTEM


@ -21,6 +21,13 @@ class WeightsLoader(ABC):
     with the format, etc.
     """

+    @abstractmethod
+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor paralllism.
+        """
+        ...
+
     @abstractmethod
     def get_weights_col_packed(
         self,
@ -77,7 +84,7 @@ class Weight(ABC):


 @dataclass
-class UnquantizedWeight:
+class UnquantizedWeight(Weight):
     weight: torch.Tensor

     def get_linear(self, bias: torch.Tensor):
@ -92,7 +99,7 @@ class UnquantizedWeight:
 class DefaultWeightsLoader(WeightsLoader):
     """Weight loader that loads (unquantized) Torch tensors."""

-    def __init__(self, weight_class):
+    def __init__(self, weight_class: Type[UnquantizedWeight]):
         """Create a loader. Weights will be wrapped using the given `weights_class`,
         normally this will be `UnquantizedWeight`, but a quantizer-specific class
         such as `Fp8Weight` can be used to quantize the weights during loading.
@ -104,6 +111,9 @@ class DefaultWeightsLoader(WeightsLoader):
         and/or concatenation.
         """

+    def get_weights(self, weights: "Weights", prefix: str):
+        return weights.get_tensor(f"{prefix}.weight")
+
     def get_weights_col_packed(
         self,
         weights: "Weights",
@ -198,20 +208,31 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str, to_device=True):
+    def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
-        # as well.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        # as well. FP8 uses torch.float8_e4m3fn
+        if (
+            tensor.dtype
+            not in [
+                torch.float8_e4m3fn,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+            ]
+            and to_dtype
+        ):
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
         return tensor

-    def get_partial_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(
+        self, tensor_name: str, dim: int, to_device=True, to_dtype=True
+    ):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@ -231,12 +252,17 @@ class Weights:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
-        if tensor.dtype not in (torch.int16, torch.int32):
+        # FP8 uses torch.float8_e4m3fn.
+        if (
+            tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
+            and to_dtype
+        ):
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@ -245,10 +271,16 @@ class Weights:
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        return self.get_partial_sharded(
+            tensor_name, dim, to_device=to_device, to_dtype=to_dtype
+        )

     def get_packed_sharded(
-        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
+        self,
+        tensor_name: str,
+        dim: int,
+        block_sizes: Union[int, List[int]],
+        to_dtype=True,
     ) -> torch.Tensor:
         """
         Get a shard from a tensor that packs multiple tensors.
@ -294,11 +326,23 @@ class Weights:
             tensor = tensor.to(device=self.device)

         # Avoid casting quantizer dtypes.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        if (
+            tensor.dtype
+            not in [
+                torch.float8_e4m3fn,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+            ]
+            and to_dtype
+        ):
             tensor = tensor.to(dtype=self.dtype)

         return tensor

+    def get_weights(self, prefix: str):
+        return self.weights_loader.get_weights(self, prefix)
+
     def get_weights_col_packed_qkv(
         self,
         prefix: str,
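The recurring condition added to `get_tensor`, `get_partial_sharded`, and `get_packed_sharded` above boils down to one rule; a hedged sketch of it, assuming a PyTorch build that defines `torch.float8_e4m3fn`:

import torch

_KEEP_DTYPES = (torch.float8_e4m3fn, torch.int16, torch.int32, torch.int64)

def maybe_cast(tensor: torch.Tensor, model_dtype: torch.dtype, to_dtype: bool = True) -> torch.Tensor:
    # Packed quantizer storage and fp8 tensors keep their dtype; everything else
    # is cast to the model dtype when to_dtype is requested.
    if to_dtype and tensor.dtype not in _KEEP_DTYPES:
        tensor = tensor.to(dtype=model_dtype)
    return tensor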