Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00

Merge branch 'main' into impl-simple-mamba-model

This commit is contained in: 9146ba00a7

.github/workflows/delete_doc_comment.yml (vendored, 12 changes)
@@ -1,12 +0,0 @@
-name: Delete doc comment
-
-on:
-  pull_request:
-    types: [ closed ]
-
-
-jobs:
-  delete:
-    uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
-    with:
-      pr_number: ${{ github.event.number }}

.gitignore (vendored, 10 changes)

@@ -2,3 +2,13 @@
 target
 router/tokenizer.json
 *__pycache__*
+
+# ROCm auto-generated files
+*.hip
+server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip_func/
+*_hip.cuh
+server/exllama_kernels/exllama_kernels/hip_buffers.cuh
+server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
+

Cargo.lock (generated, 420 changes)

File diff suppressed because it is too large. Load diff.

Cargo.toml

@@ -9,7 +9,7 @@ members = [
 resolver = "2"

 [workspace.package]
-version = "1.3.4"
+version = "1.4.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

Dockerfile_amd

@@ -75,8 +75,8 @@ RUN chmod +x ~/mambaforge.sh && \
     mamba init && \
     rm ~/mambaforge.sh

-# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
-RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
+# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
+RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/

 FROM base AS kernel-builder

@@ -104,6 +104,20 @@ WORKDIR /usr/src
 COPY server/custom_kernels/ .
 RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build

+# Build exllama kernels
+FROM kernel-builder as exllama-kernels-builder
+WORKDIR /usr/src
+COPY server/exllama_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
+# Build exllama v2 kernels
+FROM kernel-builder as exllamav2-kernels-builder
+WORKDIR /usr/src
+COPY server/exllamav2_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
 FROM base as base-copy

 # Text Generation Inference base env
@@ -120,6 +134,12 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
 # Copy build artifacts from custom kernels builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

+# Copy build artifacts from exllama kernels builder
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
+# Copy build artifacts from exllamav2 kernels builder
+COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir


README.md (41 changes)

@@ -28,7 +28,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
 - [Local Install](#local-install)
 - [CUDA Kernels](#cuda-kernels)
 - [Optimized architectures](#optimized-architectures)
-- [Run Falcon](#run-falcon)
+- [Run Mistral](#run-a-model)
 - [Run](#run)
 - [Quantization](#quantization)
 - [Develop](#develop)
@@ -42,7 +42,11 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
 - Token streaming using Server-Sent Events (SSE)
 - Continuous batching of incoming requests for increased total throughput
 - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
+- Quantization with :
+  - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+  - [GPT-Q](https://arxiv.org/abs/2210.17323)
+  - [EETQ](https://github.com/NetEase-FuXi/EETQ)
+  - [AWQ](https://github.com/casper-hansen/AutoAWQ)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@@ -51,6 +55,14 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
 - Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
 - Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance

+### Hardware support
+
+- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference)
+- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm)
+- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)
+- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)
+- [Gaudi](https://github.com/huggingface/tgi-gaudi)
+

 ## Get Started

@@ -62,7 +74,7 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
 model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
 ```

 And then you can make requests like
@@ -76,7 +88,7 @@ curl 127.0.0.1:8080/generate \

 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.

-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model` instead of the command above.

 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
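
The CPU note in the hunk above says to drop `--gpus all` and add `--disable-custom-kernels`; as a hedged sketch (not part of the diff, reusing the `$model` and `$volume` variables from the quick-start snippet), that command would look like:

```bash
# Hedged sketch: CPU-only run per the note above; reuses $model and $volume
# from the quick-start snippet. CPU is not the intended platform, so expect
# subpar performance.
docker run --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:1.4 \
    --model-id $model --disable-custom-kernels
```
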
@@ -106,7 +118,7 @@ model=meta-llama/Llama-2-7b-chat-hf
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>

-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
+docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
 ```

 ### A note on Shared Memory (shm)
@@ -154,7 +166,7 @@ Python 3.9, e.g. using `conda`:
 ```shell
 curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

-conda create -n text-generation-inference python=3.9
+conda create -n text-generation-inference python=3.11
 conda activate text-generation-inference
 ```

@@ -180,7 +192,7 @@ Then run:

 ```shell
 BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
-make run-falcon-7b-instruct
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
 ```

 **Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
@@ -189,16 +201,9 @@ make run-falcon-7b-instruct
 sudo apt-get install libssl-dev gcc -y
 ```

-### CUDA Kernels
-
-The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove
-the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
-
-Be aware that the official Docker image has them enabled by default.
-
 ## Optimized architectures

-TGI works out of the box to serve optimized models in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
+TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).

 Other architectures are supported on a best-effort basis using:

@@ -210,12 +215,12 @@ or



-## Run Falcon
+## Run locally

 ### Run

 ```shell
-make run-falcon-7b-instruct
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
 ```

 ### Quantization
@@ -223,7 +228,7 @@ make run-falcon-7b-instruct
 You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:

 ```shell
-make run-falcon-7b-instruct-quantize
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
 ```

 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
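
As a hedged illustration of the flag described in the last line of the hunk above (the model id is only an example, not taken from the diff), a 4-bit NF4 launch would look like:

```bash
# Hedged sketch: 4-bit NF4 bitsandbytes quantization via the launcher flag described above.
text-generation-launcher \
    --model-id mistralai/Mistral-7B-Instruct-v0.2 \
    --quantize bitsandbytes-nf4
```
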
|
@ -10,7 +10,7 @@
|
|||||||
"name": "Apache 2.0",
|
"name": "Apache 2.0",
|
||||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||||
},
|
},
|
||||||
"version": "1.3.4"
|
"version": "1.4.0"
|
||||||
},
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
"/": {
|
"/": {
|
||||||
@@ -342,6 +342,135 @@
           }
         }
       }
+    },
+    "/tokenize": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Tokenize inputs",
+        "description": "Tokenize inputs",
+        "operationId": "tokenize",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/GenerateRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Tokenized ids",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/TokenizeResponse"
+                }
+              }
+            }
+          },
+          "404": {
+            "description": "No tokenizer found",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "No fast tokenizer available"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
+    "/v1/chat/completions": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate tokens",
+        "description": "Generate tokens",
+        "operationId": "chat_completions",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/ChatRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Generated Text",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ChatCompletionChunk"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Input validation error"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Request failed during generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Incomplete generation"
+                }
+              }
+            }
+          }
+        }
+      }
     }
   },
   "components": {
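
To make the new `/tokenize` route concrete, here is a hedged sketch of a call; the host and port are assumptions (the quick-start examples map the container to port 8080), and the body follows the `GenerateRequest` schema referenced above.

```bash
# Hedged sketch: POST a GenerateRequest-shaped body to /tokenize.
# The 200 response is a TokenizeResponse, i.e. an array of SimpleToken objects.
curl localhost:8080/tokenize \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "What is deep learning?"}'
```
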
@@ -399,6 +528,226 @@
         }
       }
     },
+    "ChatCompletion": {
+      "type": "object",
+      "required": [
+        "id",
+        "object",
+        "created",
+        "model",
+        "system_fingerprint",
+        "choices",
+        "usage"
+      ],
+      "properties": {
+        "choices": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/ChatCompletionComplete"
+          }
+        },
+        "created": {
+          "type": "integer",
+          "format": "int64",
+          "example": "1706270835",
+          "minimum": 0
+        },
+        "id": {
+          "type": "string"
+        },
+        "model": {
+          "type": "string",
+          "example": "mistralai/Mistral-7B-Instruct-v0.2"
+        },
+        "object": {
+          "type": "string"
+        },
+        "system_fingerprint": {
+          "type": "string"
+        },
+        "usage": {
+          "$ref": "#/components/schemas/Usage"
+        }
+      }
+    },
+    "ChatCompletionChoice": {
+      "type": "object",
+      "required": [
+        "index",
+        "delta"
+      ],
+      "properties": {
+        "delta": {
+          "$ref": "#/components/schemas/ChatCompletionDelta"
+        },
+        "finish_reason": {
+          "type": "string",
+          "nullable": true
+        },
+        "index": {
+          "type": "integer",
+          "format": "int32",
+          "minimum": 0
+        },
+        "logprobs": {
+          "type": "number",
+          "format": "float",
+          "nullable": true
+        }
+      }
+    },
+    "ChatCompletionChunk": {
+      "type": "object",
+      "required": [
+        "id",
+        "object",
+        "created",
+        "model",
+        "system_fingerprint",
+        "choices"
+      ],
+      "properties": {
+        "choices": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/ChatCompletionChoice"
+          }
+        },
+        "created": {
+          "type": "integer",
+          "format": "int64",
+          "example": "1706270978",
+          "minimum": 0
+        },
+        "id": {
+          "type": "string"
+        },
+        "model": {
+          "type": "string",
+          "example": "mistralai/Mistral-7B-Instruct-v0.2"
+        },
+        "object": {
+          "type": "string"
+        },
+        "system_fingerprint": {
+          "type": "string"
+        }
+      }
+    },
+    "ChatCompletionDelta": {
+      "type": "object",
+      "required": [
+        "role",
+        "content"
+      ],
+      "properties": {
+        "content": {
+          "type": "string",
+          "example": "What is Deep Learning?"
+        },
+        "role": {
+          "type": "string",
+          "example": "user"
+        }
+      }
+    },
+    "ChatRequest": {
+      "type": "object",
+      "required": [
+        "model"
+      ],
+      "properties": {
+        "frequency_penalty": {
+          "type": "number",
+          "format": "float",
+          "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
+          "example": "1.0",
+          "nullable": true
+        },
+        "logit_bias": {
+          "type": "array",
+          "items": {
+            "type": "number",
+            "format": "float"
+          },
+          "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+          "nullable": true
+        },
+        "logprobs": {
+          "type": "boolean",
+          "description": "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\noutput token returned in the content of message.",
+          "example": "false",
+          "nullable": true
+        },
+        "max_tokens": {
+          "type": "integer",
+          "format": "int32",
+          "description": "The maximum number of tokens that can be generated in the chat completion.",
+          "example": "32",
+          "nullable": true,
+          "minimum": 0
+        },
+        "messages": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/Message"
+          },
+          "description": "A list of messages comprising the conversation so far."
+        },
+        "model": {
+          "type": "string",
+          "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
+          "example": "mistralai/Mistral-7B-Instruct-v0.2"
+        },
+        "n": {
+          "type": "integer",
+          "format": "int32",
+          "description": "UNUSED\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.",
+          "example": "2",
+          "nullable": true,
+          "minimum": 0
+        },
+        "presence_penalty": {
+          "type": "number",
+          "format": "float",
+          "description": "UNUSED\nNumber between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics",
+          "example": 0.1,
+          "nullable": true
+        },
+        "seed": {
+          "type": "integer",
+          "format": "int64",
+          "example": 42,
+          "nullable": true,
+          "minimum": 0
+        },
+        "stream": {
+          "type": "boolean"
+        },
+        "temperature": {
+          "type": "number",
+          "format": "float",
+          "description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.",
+          "example": 1.0,
+          "nullable": true
+        },
+        "top_logprobs": {
+          "type": "integer",
+          "format": "int32",
+          "description": "UNUSED\nAn integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.",
+          "example": "5",
+          "nullable": true,
+          "minimum": 0
+        },
+        "top_p": {
+          "type": "number",
+          "format": "float",
+          "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
+          "example": 0.95,
+          "nullable": true
+        }
+      }
+    },
     "CompatGenerateRequest": {
       "type": "object",
       "required": [
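
For orientation, here is a hedged sketch of a request that exercises a few of the optional `ChatRequest` fields defined above; the host and port are assumptions, and `tgi` is the placeholder model id used elsewhere in this change.

```bash
# Hedged sketch: only "model" is required; the sampling fields below are optional
# ChatRequest properties (temperature, top_p, seed, max_tokens).
curl localhost:8080/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
  "model": "tgi",
  "messages": [{"role": "user", "content": "What is deep learning?"}],
  "temperature": 0.7,
  "top_p": 0.95,
  "seed": 42,
  "max_tokens": 32
}'
```
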
@@ -494,7 +843,8 @@
         "length",
         "eos_token",
         "stop_sequence"
-      ]
+      ],
+      "example": "Length"
     },
     "GenerateParameters": {
       "type": "object",
@@ -523,7 +873,7 @@
         "max_new_tokens": {
           "type": "integer",
           "format": "int32",
-          "default": "20",
+          "default": "100",
           "example": "20",
           "nullable": true,
           "minimum": 0
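
The new default above only applies when `max_new_tokens` is omitted; as a hedged sketch (host and port are assumptions), an explicit value can still be set per request:

```bash
# Hedged sketch: an explicit max_new_tokens overrides the default shown above.
curl localhost:8080/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "What is deep learning?", "parameters": {"max_new_tokens": 20}}'
```
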
@@ -758,6 +1108,23 @@
         }
       }
     },
+    "Message": {
+      "type": "object",
+      "required": [
+        "role",
+        "content"
+      ],
+      "properties": {
+        "content": {
+          "type": "string",
+          "example": "My name is David and I"
+        },
+        "role": {
+          "type": "string",
+          "example": "user"
+        }
+      }
+    },
     "PrefillToken": {
       "type": "object",
       "required": [
@@ -784,6 +1151,37 @@
         }
       }
     },
+    "SimpleToken": {
+      "type": "object",
+      "required": [
+        "id",
+        "text",
+        "start",
+        "stop"
+      ],
+      "properties": {
+        "id": {
+          "type": "integer",
+          "format": "int32",
+          "example": 0,
+          "minimum": 0
+        },
+        "start": {
+          "type": "integer",
+          "example": 0,
+          "minimum": 0
+        },
+        "stop": {
+          "type": "integer",
+          "example": 2,
+          "minimum": 0
+        },
+        "text": {
+          "type": "string",
+          "example": "test"
+        }
+      }
+    },
     "StreamDetails": {
       "type": "object",
       "required": [
@@ -812,6 +1210,7 @@
     "StreamResponse": {
       "type": "object",
       "required": [
+        "index",
         "token"
       ],
       "properties": {
@@ -830,6 +1229,11 @@
           "example": "test",
           "nullable": true
         },
+        "index": {
+          "type": "integer",
+          "format": "int32",
+          "minimum": 0
+        },
         "token": {
           "$ref": "#/components/schemas/Token"
         },
@@ -871,6 +1275,12 @@
           "example": "test"
         }
       }
+    },
+    "TokenizeResponse": {
+      "type": "array",
+      "items": {
+        "$ref": "#/components/schemas/SimpleToken"
+      }
     }
   }
 },

docs/source/_toctree.yml

@@ -7,6 +7,8 @@
     title: Installation
   - local: supported_models
     title: Supported Models and Hardware
+  - local: messages_api
+    title: Messages API
   title: Getting started
 - sections:
   - local: basic_tutorials/consuming_tgi
@@ -19,6 +19,6 @@ docker run --gpus all \
   --shm-size 1g \
   -e HUGGING_FACE_HUB_TOKEN=$token \
   -p 8080:80 \
-  -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 \
+  -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
   --model-id $model
 ```

docs/source/basic_tutorials/launcher.md

@@ -60,9 +60,9 @@ Options:
   [env: QUANTIZE=]

   Possible values:
-  - awq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models whereever possible because of the better latency
+  - awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
   - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
-  - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels whereever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
+  - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
   - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
   - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
   - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
@@ -354,6 +354,14 @@ Options:

   [env: NGROK_EDGE=]

+```
+## TOKENIZER_CONFIG_PATH
+```shell
+      --tokenizer-config-path <TOKENIZER_CONFIG_PATH>
+          The path to the tokenizer config file. This path is used to load the tokenizer configuration which may include a `chat_template`. If not provided, the default config will be used from the model hub
+
+          [env: TOKENIZER_CONFIG_PATH=]
+
 ```
 ## ENV
 ```shell
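
A hedged usage sketch for the new option (the model id and file path are placeholders, not taken from the diff):

```bash
# Hedged sketch: point the launcher at a local tokenizer config,
# e.g. one that carries a chat_template. Path and model id are hypothetical.
text-generation-launcher \
    --model-id mistralai/Mistral-7B-Instruct-v0.2 \
    --tokenizer-config-path /data/tokenizer_config.json
```
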

docs/source/basic_tutorials/using_cli.md

@@ -1,6 +1,6 @@
 # Using TGI CLI

-You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](./installation#install-cli).
+You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli).

 `text-generation-server` lets you download the model with `download-weights` command like below 👇


docs/source/messages_api.md (new file, 175 lines)

# Messages API

Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.

> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.

#### Table of Contents

- [Making a Request](#making-a-request)
- [Streaming](#streaming)
- [Synchronous](#synchronous)
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
- [Cloud Providers](#cloud-providers)
  - [Amazon SageMaker](#amazon-sagemaker)

## Making a Request

You can make a request to TGI's Messages API using `curl`. Here's an example:

```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is deep learning?"
    }
  ],
  "stream": true,
  "max_tokens": 20
}' \
    -H 'Content-Type: application/json'
```

## Streaming

You can also use OpenAI's Python client library to make a streaming request. Here's how:

```python
from openai import OpenAI

# init the client but point it to TGI
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="-"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant." },
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=True
)

# iterate and print stream
for message in chat_completion:
    print(message)
```

## Synchronous

If you prefer to make a synchronous request, you can do so like this:

```python
from openai import OpenAI

# init the client but point it to TGI
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="-"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant." },
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=False
)

print(chat_completion)
```

## Hugging Face Inference Endpoints

The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated).
Every endpoint that uses "Text Generation Inference" with an LLM, which has a chat template can now be used. Below is an example of how to use IE with TGI using OpenAI's Python client library:

> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key.

```python
from openai import OpenAI

# init the client but point it to TGI
client = OpenAI(
    # replace with your endpoint url, make sure to include "v1/" at the end
    base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/",
    # replace with your API key
    api_key="hf_XXX"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant." },
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=True
)

# iterate and print stream
for message in chat_completion:
    print(message.choices[0].delta.content, end="")
```

## Cloud Providers

TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:

## Amazon SageMaker

To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.

This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.

```python
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
    'SM_NUM_GPUS': json.dumps(1),
    'MESSAGES_API_ENABLED': True
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
    env=hub,
    role=role,
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=300,
)

# send request
predictor.predict({
    "messages": [
        {"role": "system", "content": "You are a helpful assistant." },
        {"role": "user", "content": "What is deep learning?"}
    ]
})
```

docs/source/quicktour.md

@@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
 model=tiiuae/falcon-7b-instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
 ```

 <Tip warning={true}>
@@ -20,7 +20,7 @@ To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://d
 TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:

 ```bash
-docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model
+docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
 ```

 Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
@@ -91,7 +91,7 @@ curl 127.0.0.1:8080/generate \
 To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

 ```bash
-docker run ghcr.io/huggingface/text-generation-inference:1.3 --help
+docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
 ```

 </Tip>

docs/source/supported_models.md

@@ -19,7 +19,9 @@ The following models are optimized and can be served with TGI, which uses custom
 - [MPT](https://huggingface.co/mosaicml/mpt-30b)
 - [Llama V2](https://huggingface.co/meta-llama)
 - [Code Llama](https://huggingface.co/codellama)
-- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
+- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- [Phi](https://huggingface.co/microsoft/phi-2)

 If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:

@@ -41,8 +43,8 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>

 TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.

-TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
-* Quantization (GPTQ, AWQ, etc.)
+TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
+* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
 * Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
 * Kernel for slinding window attention (Mistral)


New file (84 added lines):

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 14402,
        "logprob": null,
        "text": "Test"
      },
      {
        "id": 2581,
        "logprob": -11.6171875,
        "text": " request"
      }
    ],
    "seed": null,
    "tokens": [
      {
        "id": 25,
        "logprob": -2.3203125,
        "special": false,
        "text": ":"
      },
      {
        "id": 1391,
        "logprob": -0.98779297,
        "special": false,
        "text": " {"
      },
      {
        "id": 25927,
        "logprob": -0.76660156,
        "special": false,
        "text": "request"
      },
      {
        "id": 92,
        "logprob": -0.7246094,
        "special": false,
        "text": "}"
      },
      {
        "id": 4943,
        "logprob": -0.41333008,
        "special": false,
        "text": "\")"
      },
      {
        "id": 198,
        "logprob": -0.11785889,
        "special": false,
        "text": "\n"
      },
      {
        "id": 50280,
        "logprob": -0.97265625,
        "special": false,
        "text": " "
      },
      {
        "id": 26209,
        "logprob": -1.4414062,
        "special": false,
        "text": "response"
      },
      {
        "id": 796,
        "logprob": -0.0569458,
        "special": false,
        "text": " ="
      },
      {
        "id": 2116,
        "logprob": -1.1533203,
        "special": false,
        "text": " self"
      }
    ],
    "top_tokens": null
  },
  "generated_text": ": {request}\")\n response = self"
}

New file (60 added lines):

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "stop_sequence",
    "generated_tokens": 6,
    "prefill": [
      {
        "id": 14402,
        "logprob": null,
        "text": "Test"
      },
      {
        "id": 2581,
        "logprob": -11.6171875,
        "text": " request"
      }
    ],
    "seed": 0,
    "tokens": [
      {
        "id": 284,
        "logprob": -0.19421387,
        "special": false,
        "text": " to"
      },
      {
        "id": 3758,
        "logprob": -0.62597656,
        "special": false,
        "text": " send"
      },
      {
        "id": 1366,
        "logprob": -0.87060547,
        "special": false,
        "text": " data"
      },
      {
        "id": 625,
        "logprob": -0.88427734,
        "special": false,
        "text": " over"
      },
      {
        "id": 257,
        "logprob": -1.0830078,
        "special": false,
        "text": " a"
      },
      {
        "id": 3127,
        "logprob": -1.9462891,
        "special": false,
        "text": " network"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "Test request to send data over a network"
}

New file (338 added lines):

[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 14402,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2581,
          "logprob": -11.6171875,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 25,
          "logprob": -2.3203125,
          "special": false,
          "text": ":"
        },
        {
          "id": 1391,
          "logprob": -0.98779297,
          "special": false,
          "text": " {"
        },
        {
          "id": 25927,
          "logprob": -0.7729492,
          "special": false,
          "text": "request"
        },
        {
          "id": 92,
          "logprob": -0.7241211,
          "special": false,
          "text": "}"
        },
        {
          "id": 4943,
          "logprob": -0.4091797,
          "special": false,
          "text": "\")"
        },
        {
          "id": 198,
          "logprob": -0.119018555,
          "special": false,
          "text": "\n"
        },
        {
          "id": 50280,
          "logprob": -0.9707031,
          "special": false,
          "text": " "
        },
        {
          "id": 26209,
          "logprob": -1.4414062,
          "special": false,
          "text": "response"
        },
        {
          "id": 796,
          "logprob": -0.056854248,
          "special": false,
          "text": " ="
        },
        {
          "id": 2116,
          "logprob": -1.1533203,
          "special": false,
          "text": " self"
        }
      ],
      "top_tokens": null
    },
    "generated_text": ": {request}\")\n response = self"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 14402,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2581,
          "logprob": -11.6171875,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 25,
          "logprob": -2.3203125,
          "special": false,
          "text": ":"
        },
        {
          "id": 1391,
          "logprob": -0.98779297,
          "special": false,
          "text": " {"
        },
        {
          "id": 25927,
          "logprob": -0.7729492,
          "special": false,
          "text": "request"
        },
        {
          "id": 92,
          "logprob": -0.7241211,
          "special": false,
          "text": "}"
        },
        {
          "id": 4943,
          "logprob": -0.4091797,
          "special": false,
          "text": "\")"
        },
        {
          "id": 198,
          "logprob": -0.119018555,
          "special": false,
          "text": "\n"
        },
        {
          "id": 50280,
          "logprob": -0.9707031,
          "special": false,
          "text": " "
        },
        {
          "id": 26209,
          "logprob": -1.4414062,
          "special": false,
          "text": "response"
        },
        {
          "id": 796,
          "logprob": -0.056854248,
          "special": false,
          "text": " ="
        },
        {
          "id": 2116,
          "logprob": -1.1533203,
          "special": false,
          "text": " self"
        }
      ],
      "top_tokens": null
    },
    "generated_text": ": {request}\")\n response = self"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 14402,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2581,
          "logprob": -11.6171875,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 25,
          "logprob": -2.3203125,
          "special": false,
          "text": ":"
        },
        {
          "id": 1391,
          "logprob": -0.98779297,
          "special": false,
          "text": " {"
        },
        {
          "id": 25927,
          "logprob": -0.7729492,
          "special": false,
          "text": "request"
        },
        {
          "id": 92,
          "logprob": -0.7241211,
          "special": false,
          "text": "}"
        },
        {
          "id": 4943,
          "logprob": -0.4091797,
          "special": false,
          "text": "\")"
        },
        {
          "id": 198,
          "logprob": -0.119018555,
          "special": false,
          "text": "\n"
        },
        {
          "id": 50280,
          "logprob": -0.9707031,
          "special": false,
          "text": " "
        },
        {
          "id": 26209,
          "logprob": -1.4414062,
          "special": false,
          "text": "response"
        },
        {
          "id": 796,
          "logprob": -0.056854248,
          "special": false,
          "text": " ="
        },
        {
          "id": 2116,
          "logprob": -1.1533203,
          "special": false,
          "text": " self"
        }
      ],
      "top_tokens": null
    },
    "generated_text": ": {request}\")\n response = self"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 14402,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2581,
          "logprob": -11.6171875,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 25,
          "logprob": -2.3203125,
          "special": false,
          "text": ":"
        },
        {
          "id": 1391,
          "logprob": -0.98779297,
          "special": false,
          "text": " {"
        },
        {
          "id": 25927,
          "logprob": -0.7729492,
          "special": false,
          "text": "request"
        },
        {
          "id": 92,
          "logprob": -0.7241211,
          "special": false,
          "text": "}"
        },
        {
          "id": 4943,
          "logprob": -0.4091797,
          "special": false,
          "text": "\")"
        },
        {
          "id": 198,
          "logprob": -0.119018555,
          "special": false,
          "text": "\n"
        },
        {
          "id": 50280,
          "logprob": -0.9707031,
          "special": false,
          "text": " "
        },
        {
          "id": 26209,
          "logprob": -1.4414062,
          "special": false,
          "text": "response"
        },
        {
          "id": 796,
          "logprob": -0.056854248,
          "special": false,
          "text": " ="
        },
        {
          "id": 2116,
          "logprob": -1.1533203,
          "special": false,
|
||||||
|
"text": " self"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": {request}\")\n response = self"
|
||||||
|
}
|
||||||
|
]
|
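The snapshots above record the token-level detail (prefill tokens, per-token logprobs, and the final generated_text) that the server returns when decoder_input_details is requested. As a rough illustration only, here is a minimal Python sketch using the text-generation client; the base URL is an assumption and must be adapted to your own deployment.

# Minimal sketch: request the same kind of token-level details as the snapshot above.
# Assumption: a running TGI server reachable at BASE_URL (adjust to your deployment).
from text_generation import Client

BASE_URL = "http://127.0.0.1:8080"  # assumed address, not part of this PR

client = Client(BASE_URL)
response = client.generate(
    "Test request",
    max_new_tokens=10,
    decoder_input_details=True,  # include prefill tokens and their logprobs
)

print(response.generated_text)
for token in response.details.tokens:
    print(token.id, repr(token.text), token.logprob, token.special)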
@ -16,52 +16,52 @@
-      { "id": 21017, "logprob": -9.09375, "text": "ometric" },
+      { "id": 21017, "logprob": -9.0859375, "text": "ometric" },
-      { "id": 81, "logprob": -0.25976562, "text": "_" },
+      { "id": 81, "logprob": -0.25830078, "text": "_" },
-      { "id": 6009, "logprob": -2.2148438, "text": "mean" },
+      { "id": 6009, "logprob": -2.1875, "text": "mean" },
-      { "id": 26, "logprob": -0.3010254, "text": "(" },
+      { "id": 26, "logprob": -0.30004883, "text": "(" },
-      { "id": 62, "logprob": -5.6757812, "text": "L" },
+      { "id": 62, "logprob": -5.6171875, "text": "L" },
-      { "id": 44, "logprob": -3.0898438, "text": ":" },
+      { "id": 44, "logprob": -3.078125, "text": ":" },
-      { "id": 1682, "logprob": -0.6791992, "text": " List" },
+      { "id": 1682, "logprob": -0.68066406, "text": " List" },
-      { "id": 77, "logprob": -0.38891602, "text": "[" },
+      { "id": 77, "logprob": -0.38745117, "text": "[" },
-      { "id": 1808, "logprob": -0.92041016, "text": "float" },
+      { "id": 1808, "logprob": -0.9453125, "text": "float" },
-      { "id": 10794, "logprob": -2.5390625, "text": "]):" }
+      { "id": 10794, "logprob": -2.5371094, "text": "]):" }
     ],
@ -69,7 +69,7 @@
     "tokens": [
-      { "id": 284, "logprob": 0.0, "special": false, "text": "\n " },
+      { "id": 284, "logprob": -0.051635742, "special": false, "text": "\n " },
@ -81,7 +81,7 @@
-      { "id": 11665, "logprob": -1.6005859, "special": false, "text": " reduce" },
+      { "id": 11665, "logprob": -1.2236328, "special": false, "text": " reduce" },
@ -159,7 +159,7 @@
-      { "id": 203, "logprob": -0.11968994, "special": false, "text": "\n" },
+      { "id": 203, "logprob": -0.12695312, "special": false, "text": "\n" },
@ -11,92 +11,92 @@
-      { "id": 4911, "logprob": -5.7851562, "text": "User" },
+      { "id": 4911, "logprob": -6.9765625, "text": "User" },
-      { "id": 29901, "logprob": -0.006996155, "text": ":" },
+      { "id": 29901, "logprob": -0.0059432983, "text": ":" },
-      { "id": 32000, "logprob": -0.81347656, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -0.8408203, "text": "<fake_token_around_image>" },
-      { "id": 32001, "logprob": -6.687641e-05, "text": "<image>" },
+      { "id": 32001, "logprob": -9.906292e-05, "text": "<image>" },
-      { "id": 32000, "logprob": -3.5762787e-07, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -2.3841858e-07, "text": "<fake_token_around_image>" },
-      { "id": 1815, "logprob": -4.2148438, "text": "Can" },
+      { "id": 1815, "logprob": -4.1679688, "text": "Can" },
-      { "id": 366, "logprob": -0.014137268, "text": "you" },
+      { "id": 366, "logprob": -0.014099121, "text": "you" },
-      { "id": 2649, "logprob": -4.4335938, "text": "tell" },
+      { "id": 2649, "logprob": -4.4609375, "text": "tell" },
-      { "id": 592, "logprob": -0.2919922, "text": "me" },
+      { "id": 592, "logprob": -0.29882812, "text": "me" },
-      { "id": 263, "logprob": -4.2070312, "text": "a" },
+      { "id": 263, "logprob": -4.1445312, "text": "a" },
-      { "id": 1407, "logprob": -9.421875, "text": "very" },
+      { "id": 1407, "logprob": -9.3828125, "text": "very" },
-      { "id": 3273, "logprob": -1.8720703, "text": "short" },
+      { "id": 3273, "logprob": -1.9736328, "text": "short" },
-      { "id": 5828, "logprob": -0.26489258, "text": "story" },
+      { "id": 5828, "logprob": -0.2800293, "text": "story" },
-      { "id": 2729, "logprob": -3.7441406, "text": "based" },
+      { "id": 2729, "logprob": -3.5625, "text": "based" },
-      { "id": 373, "logprob": -0.0005393028, "text": "on" },
+      { "id": 373, "logprob": -0.0006427765, "text": "on" },
-      { "id": 278, "logprob": -0.140625, "text": "the" },
+      { "id": 278, "logprob": -0.13952637, "text": "the" },
-      { "id": 1967, "logprob": -0.06756592, "text": "image" },
+      { "id": 1967, "logprob": -0.068115234, "text": "image" },
-      { "id": 29973, "logprob": -0.15454102, "text": "?" }
+      { "id": 29973, "logprob": -0.16357422, "text": "?" }
     ],
@ -104,25 +104,25 @@
     "tokens": [
-      { "id": 32002, "logprob": -0.0019140244, "special": true, "text": "<end_of_utterance>" },
+      { "id": 32002, "logprob": -0.0026474, "special": true, "text": "<end_of_utterance>" },
-      { "id": 29871, "logprob": -8.404255e-05, "special": false, "text": " " },
+      { "id": 29871, "logprob": -8.547306e-05, "special": false, "text": " " },
-      { "id": 13, "logprob": -1.7642975e-05, "special": false, "text": "\n" },
+      { "id": 13, "logprob": -1.7881393e-05, "special": false, "text": "\n" },
-      { "id": 7900, "logprob": -2.9802322e-06, "special": false, "text": "Ass" },
+      { "id": 7900, "logprob": -3.0994415e-06, "special": false, "text": "Ass" },
@ -140,30 +140,29 @@
-      { "id": 319, "logprob": -0.91064453, "special": false, "text": " A" },
+      { "id": 319, "logprob": -0.92529297, "special": false, "text": " A" },
-      { "id": 696, "logprob": -1.2412109, "special": false, "text": " ro" },
+      { "id": 696, "logprob": -1.1269531, "special": false, "text": " ro" },
-      { "id": 15664, "logprob": -0.0002439022, "special": false, "text": "oster" },
+      { "id": 15664, "logprob": -0.00029492378, "special": false, "text": "oster" },
-      { "id": 15028, "logprob": -1.1630859, "special": false, "text": " stands" }
+      { "id": 15028, "logprob": -1.1855469, "special": false, "text": " stands" }
-    ],
-    "top_tokens": null
+    ]
   },
   "generated_text": " \nAssistant: A rooster stands"
 }
@ -12,92 +12,92 @@
-      { "id": 4911, "logprob": -5.7851562, "text": "User" },
+      { "id": 4911, "logprob": -6.9804688, "text": "User" },
-      { "id": 29901, "logprob": -0.006996155, "text": ":" },
+      { "id": 29901, "logprob": -0.006122589, "text": ":" },
-      { "id": 32000, "logprob": -0.81347656, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -0.8417969, "text": "<fake_token_around_image>" },
-      { "id": 32001, "logprob": -6.687641e-05, "text": "<image>" },
+      { "id": 32001, "logprob": -9.918213e-05, "text": "<image>" },
-      { "id": 32000, "logprob": -3.5762787e-07, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -2.3841858e-07, "text": "<fake_token_around_image>" },
-      { "id": 1815, "logprob": -4.2148438, "text": "Can" },
+      { "id": 1815, "logprob": -4.1679688, "text": "Can" },
-      { "id": 366, "logprob": -0.014137268, "text": "you" },
+      { "id": 366, "logprob": -0.014091492, "text": "you" },
-      { "id": 2649, "logprob": -4.4335938, "text": "tell" },
+      { "id": 2649, "logprob": -4.4726562, "text": "tell" },
-      { "id": 592, "logprob": -0.2919922, "text": "me" },
+      { "id": 592, "logprob": -0.2998047, "text": "me" },
-      { "id": 263, "logprob": -4.2070312, "text": "a" },
+      { "id": 263, "logprob": -4.15625, "text": "a" },
-      { "id": 1407, "logprob": -9.421875, "text": "very" },
+      { "id": 1407, "logprob": -9.3828125, "text": "very" },
-      { "id": 3273, "logprob": -1.8720703, "text": "short" },
+      { "id": 3273, "logprob": -1.9716797, "text": "short" },
-      { "id": 5828, "logprob": -0.26489258, "text": "story" },
+      { "id": 5828, "logprob": -0.27734375, "text": "story" },
-      { "id": 2729, "logprob": -3.7441406, "text": "based" },
+      { "id": 2729, "logprob": -3.5605469, "text": "based" },
-      { "id": 373, "logprob": -0.0005393028, "text": "on" },
+      { "id": 373, "logprob": -0.00064468384, "text": "on" },
-      { "id": 278, "logprob": -0.140625, "text": "the" },
+      { "id": 278, "logprob": -0.14160156, "text": "the" },
-      { "id": 1967, "logprob": -0.06756592, "text": "image" },
+      { "id": 1967, "logprob": -0.06915283, "text": "image" },
-      { "id": 29973, "logprob": -0.15454102, "text": "?" }
+      { "id": 29973, "logprob": -0.16381836, "text": "?" }
     ],
@ -105,19 +105,19 @@
     "tokens": [
-      { "id": 32002, "logprob": -0.0019140244, "special": true, "text": "<end_of_utterance>" },
+      { "id": 32002, "logprob": -0.0026664734, "special": true, "text": "<end_of_utterance>" },
-      { "id": 29871, "logprob": -8.392334e-05, "special": false, "text": " " },
+      { "id": 29871, "logprob": -8.583069e-05, "special": false, "text": " " },
-      { "id": 13, "logprob": -1.7881393e-05, "special": false, "text": "\n" },
+      { "id": 13, "logprob": -1.8119812e-05, "special": false, "text": "\n" },
@ -135,36 +135,35 @@
-      { "id": 29901, "logprob": -3.0994415e-06, "special": false, "text": ":" },
+      { "id": 29901, "logprob": -3.2186508e-06, "special": false, "text": ":" },
-      { "id": 319, "logprob": -0.9057617, "special": false, "text": " A" },
+      { "id": 319, "logprob": -0.9301758, "special": false, "text": " A" },
-      { "id": 696, "logprob": -1.2294922, "special": false, "text": " ro" },
+      { "id": 696, "logprob": -1.1279297, "special": false, "text": " ro" },
-      { "id": 15664, "logprob": -0.00024533272, "special": false, "text": "oster" },
+      { "id": 15664, "logprob": -0.0002939701, "special": false, "text": "oster" },
-      { "id": 15028, "logprob": -1.1640625, "special": false, "text": " stands" }
+      { "id": 15028, "logprob": -1.1865234, "special": false, "text": " stands" }
-    ],
-    "top_tokens": null
+    ]
   },
   "generated_text": " \nAssistant: A rooster stands"
 },
@ -181,92 +180,92 @@
-      { "id": 4911, "logprob": -5.7773438, "text": "User" },
+      { "id": 4911, "logprob": -6.9804688, "text": "User" },
-      { "id": 29901, "logprob": -0.0070114136, "text": ":" },
+      { "id": 29901, "logprob": -0.006122589, "text": ":" },
-      { "id": 32000, "logprob": -0.8208008, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -0.8417969, "text": "<fake_token_around_image>" },
-      { "id": 32001, "logprob": -6.699562e-05, "text": "<image>" },
+      { "id": 32001, "logprob": -9.942055e-05, "text": "<image>" },
-      { "id": 32000, "logprob": -3.5762787e-07, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -2.3841858e-07, "text": "<fake_token_around_image>" },
-      { "id": 1815, "logprob": -4.2265625, "text": "Can" },
+      { "id": 1815, "logprob": -4.1679688, "text": "Can" },
-      { "id": 366, "logprob": -0.014175415, "text": "you" },
+      { "id": 366, "logprob": -0.014091492, "text": "you" },
-      { "id": 2649, "logprob": -4.4296875, "text": "tell" },
+      { "id": 2649, "logprob": -4.4726562, "text": "tell" },
-      { "id": 592, "logprob": -0.29516602, "text": "me" },
+      { "id": 592, "logprob": -0.2998047, "text": "me" },
-      { "id": 263, "logprob": -4.2109375, "text": "a" },
+      { "id": 263, "logprob": -4.15625, "text": "a" },
-      { "id": 1407, "logprob": -9.4296875, "text": "very" },
+      { "id": 1407, "logprob": -9.3828125, "text": "very" },
-      { "id": 3273, "logprob": -1.8720703, "text": "short" },
+      { "id": 3273, "logprob": -1.9716797, "text": "short" },
-      { "id": 5828, "logprob": -0.26879883, "text": "story" },
+      { "id": 5828, "logprob": -0.27734375, "text": "story" },
-      { "id": 2729, "logprob": -3.7675781, "text": "based" },
+      { "id": 2729, "logprob": -3.5605469, "text": "based" },
-      { "id": 373, "logprob": -0.0005354881, "text": "on" },
+      { "id": 373, "logprob": -0.0006451607, "text": "on" },
-      { "id": 278, "logprob": -0.13671875, "text": "the" },
+      { "id": 278, "logprob": -0.14160156, "text": "the" },
-      { "id": 1967, "logprob": -0.06719971, "text": "image" },
+      { "id": 1967, "logprob": -0.06915283, "text": "image" },
-      { "id": 29973, "logprob": -0.15551758, "text": "?" }
+      { "id": 29973, "logprob": -0.16381836, "text": "?" }
     ],
@ -274,19 +273,19 @@
     "tokens": [
-      { "id": 32002, "logprob": -0.0019130707, "special": true, "text": "<end_of_utterance>" },
+      { "id": 32002, "logprob": -0.0026664734, "special": true, "text": "<end_of_utterance>" },
-      { "id": 29871, "logprob": -8.392334e-05, "special": false, "text": " " },
+      { "id": 29871, "logprob": -8.571148e-05, "special": false, "text": " " },
-      { "id": 13, "logprob": -1.7881393e-05, "special": false, "text": "\n" },
+      { "id": 13, "logprob": -1.8119812e-05, "special": false, "text": "\n" },
@ -310,30 +309,29 @@
-      { "id": 319, "logprob": -0.9013672, "special": false, "text": " A" },
+      { "id": 319, "logprob": -0.9301758, "special": false, "text": " A" },
-      { "id": 696, "logprob": -1.2324219, "special": false, "text": " ro" },
+      { "id": 696, "logprob": -1.1279297, "special": false, "text": " ro" },
-      { "id": 15664, "logprob": -0.0002477169, "special": false, "text": "oster" },
+      { "id": 15664, "logprob": -0.0002939701, "special": false, "text": "oster" },
-      { "id": 15028, "logprob": -1.1660156, "special": false, "text": " stands" }
+      { "id": 15028, "logprob": -1.1865234, "special": false, "text": " stands" }
-    ],
-    "top_tokens": null
+    ]
   },
   "generated_text": " \nAssistant: A rooster stands"
 },
@ -350,92 +348,92 @@
-      { "id": 4911, "logprob": -5.7773438, "text": "User" },
+      { "id": 4911, "logprob": -6.9804688, "text": "User" },
-      { "id": 29901, "logprob": -0.0070114136, "text": ":" },
+      { "id": 29901, "logprob": -0.006122589, "text": ":" },
-      { "id": 32000, "logprob": -0.8208008, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -0.8417969, "text": "<fake_token_around_image>" },
-      { "id": 32001, "logprob": -6.699562e-05, "text": "<image>" },
+      { "id": 32001, "logprob": -9.918213e-05, "text": "<image>" },
-      { "id": 32000, "logprob": -3.5762787e-07, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -2.3841858e-07, "text": "<fake_token_around_image>" },
-      { "id": 1815, "logprob": -4.2265625, "text": "Can" },
+      { "id": 1815, "logprob": -4.1679688, "text": "Can" },
-      { "id": 366, "logprob": -0.014175415, "text": "you" },
+      { "id": 366, "logprob": -0.014091492, "text": "you" },
-      { "id": 2649, "logprob": -4.4296875, "text": "tell" },
+      { "id": 2649, "logprob": -4.4726562, "text": "tell" },
-      { "id": 592, "logprob": -0.29516602, "text": "me" },
+      { "id": 592, "logprob": -0.2998047, "text": "me" },
-      { "id": 263, "logprob": -4.2109375, "text": "a" },
+      { "id": 263, "logprob": -4.15625, "text": "a" },
-      { "id": 1407, "logprob": -9.4296875, "text": "very" },
+      { "id": 1407, "logprob": -9.3828125, "text": "very" },
-      { "id": 3273, "logprob": -1.8720703, "text": "short" },
+      { "id": 3273, "logprob": -1.9716797, "text": "short" },
-      { "id": 5828, "logprob": -0.26879883, "text": "story" },
+      { "id": 5828, "logprob": -0.27734375, "text": "story" },
-      { "id": 2729, "logprob": -3.7675781, "text": "based" },
+      { "id": 2729, "logprob": -3.5605469, "text": "based" },
-      { "id": 373, "logprob": -0.0005354881, "text": "on" },
+      { "id": 373, "logprob": -0.00064468384, "text": "on" },
-      { "id": 278, "logprob": -0.13671875, "text": "the" },
+      { "id": 278, "logprob": -0.14160156, "text": "the" },
-      { "id": 1967, "logprob": -0.06719971, "text": "image" },
+      { "id": 1967, "logprob": -0.06915283, "text": "image" },
-      { "id": 29973, "logprob": -0.15551758, "text": "?" }
+      { "id": 29973, "logprob": -0.16381836, "text": "?" }
     ],
@ -443,19 +441,19 @@
     "tokens": [
-      { "id": 32002, "logprob": -0.001912117, "special": true, "text": "<end_of_utterance>" },
+      { "id": 32002, "logprob": -0.0026664734, "special": true, "text": "<end_of_utterance>" },
-      { "id": 29871, "logprob": -8.392334e-05, "special": false, "text": " " },
+      { "id": 29871, "logprob": -8.59499e-05, "special": false, "text": " " },
-      { "id": 13, "logprob": -1.7762184e-05, "special": false, "text": "\n" },
+      { "id": 13, "logprob": -1.8119812e-05, "special": false, "text": "\n" },
@ -479,30 +477,29 @@
-      { "id": 319, "logprob": -0.9013672, "special": false, "text": " A" },
+      { "id": 319, "logprob": -0.9301758, "special": false, "text": " A" },
-      { "id": 696, "logprob": -1.2324219, "special": false, "text": " ro" },
+      { "id": 696, "logprob": -1.1279297, "special": false, "text": " ro" },
-      { "id": 15664, "logprob": -0.0002477169, "special": false, "text": "oster" },
+      { "id": 15664, "logprob": -0.0002939701, "special": false, "text": "oster" },
-      { "id": 15028, "logprob": -1.1660156, "special": false, "text": " stands" }
+      { "id": 15028, "logprob": -1.1865234, "special": false, "text": " stands" }
-    ],
-    "top_tokens": null
+    ]
   },
   "generated_text": " \nAssistant: A rooster stands"
 },
@ -519,92 +516,92 @@
-      { "id": 4911, "logprob": -5.7773438, "text": "User" },
+      { "id": 4911, "logprob": -6.9804688, "text": "User" },
-      { "id": 29901, "logprob": -0.0070114136, "text": ":" },
+      { "id": 29901, "logprob": -0.006122589, "text": ":" },
-      { "id": 32000, "logprob": -0.8208008, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -0.8417969, "text": "<fake_token_around_image>" },
-      { "id": 32001, "logprob": -6.699562e-05, "text": "<image>" },
+      { "id": 32001, "logprob": -9.942055e-05, "text": "<image>" },
-      { "id": 32000, "logprob": -3.5762787e-07, "text": "<fake_token_around_image>" },
+      { "id": 32000, "logprob": -2.3841858e-07, "text": "<fake_token_around_image>" },
-      { "id": 1815, "logprob": -4.2265625, "text": "Can" },
+      { "id": 1815, "logprob": -4.1679688, "text": "Can" },
-      { "id": 366, "logprob": -0.014175415, "text": "you" },
+      { "id": 366, "logprob": -0.014091492, "text": "you" },
-      { "id": 2649, "logprob": -4.4296875, "text": "tell" },
+      { "id": 2649, "logprob": -4.4726562, "text": "tell" },
-      { "id": 592, "logprob": -0.29516602, "text": "me" },
+      { "id": 592, "logprob": -0.2998047, "text": "me" },
-      { "id": 263, "logprob": -4.2109375, "text": "a" },
+      { "id": 263, "logprob": -4.15625, "text": "a" },
-      { "id": 1407, "logprob": -9.4296875, "text": "very" },
+      { "id": 1407, "logprob": -9.3828125, "text": "very" },
-      { "id": 3273, "logprob": -1.8720703, "text": "short" },
+      { "id": 3273, "logprob": -1.9716797, "text": "short" },
-      { "id": 5828, "logprob": -0.26879883, "text": "story" },
+      { "id": 5828, "logprob": -0.27734375, "text": "story" },
-      { "id": 2729, "logprob": -3.7675781, "text": "based" },
+      { "id": 2729, "logprob": -3.5605469, "text": "based" },
-      { "id": 373, "logprob": -0.0005354881, "text": "on" },
+      { "id": 373, "logprob": -0.0006451607, "text": "on" },
-      { "id": 278, "logprob": -0.13671875, "text": "the" },
+      { "id": 278, "logprob": -0.14160156, "text": "the" },
-      { "id": 1967, "logprob": -0.06719971, "text": "image" },
+      { "id": 1967, "logprob": -0.06915283, "text": "image" },
-      { "id": 29973, "logprob": -0.15551758, "text": "?" }
+      { "id": 29973, "logprob": -0.16381836, "text": "?" }
     ],
@ -612,19 +609,19 @@
     "tokens": [
-      { "id": 32002, "logprob": -0.001912117, "special": true, "text": "<end_of_utterance>" },
+      { "id": 32002, "logprob": -0.0026664734, "special": true, "text": "<end_of_utterance>" },
-      { "id": 29871, "logprob": -8.392334e-05, "special": false, "text": " " },
+      { "id": 29871, "logprob": -8.571148e-05, "special": false, "text": " " },
-      { "id": 13, "logprob": -1.7762184e-05, "special": false, "text": "\n" },
+      { "id": 13, "logprob": -1.8119812e-05, "special": false, "text": "\n" },
@ -648,30 +645,29 @@
-      { "id": 319, "logprob": -0.9013672, "special": false, "text": " A" },
+      { "id": 319, "logprob": -0.9301758, "special": false, "text": " A" },
-      { "id": 696, "logprob": -1.2324219, "special": false, "text": " ro" },
+      { "id": 696, "logprob": -1.1279297, "special": false, "text": " ro" },
-      { "id": 15664, "logprob": -0.0002477169, "special": false, "text": "oster" },
+      { "id": 15664, "logprob": -0.0002939701, "special": false, "text": "oster" },
-      { "id": 15028, "logprob": -1.1660156, "special": false, "text": " stands" }
+      { "id": 15028, "logprob": -1.1865234, "special": false, "text": " stands" }
-    ],
-    "top_tokens": null
+    ]
   },
   "generated_text": " \nAssistant: A rooster stands"
 }
63
integration-tests/models/test_flash_phi.py
Normal file
63
integration-tests/models/test_flash_phi.py
Normal file
@ -0,0 +1,63 @@
import pytest


@pytest.fixture(scope="module")
def flash_phi_handle(launcher):
    with launcher("microsoft/phi-2", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_phi(flash_phi_handle):
    await flash_phi_handle.health(300)
    return flash_phi_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
    response = await flash_phi.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response.generated_text == ': {request}")\n response = self'
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
    response = await flash_phi.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["network"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 6
    assert response.generated_text == "Test request to send data over a network"
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
    responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"
    assert responses[0].generated_text == ': {request}")\n response = self'

    assert responses == response_snapshot
@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.3.4"
+version = "1.4.0"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]
 
@ -4,7 +4,7 @@ use nix::unistd::Pid;
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
-use std::io::{BufRead, BufReader, Lines, Read};
+use std::io::{BufRead, BufReader, Lines};
 use std::os::unix::process::{CommandExt, ExitStatusExt};
 use std::path::Path;
 use std::process::{Child, Command, ExitStatus, Stdio};
@ -21,16 +21,16 @@ mod env_runtime;
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Quantization {
-    /// 4 bit quantization. Requires a specific GTPQ quantized model:
+    /// 4 bit quantization. Requires a specific AWQ quantized model:
     /// https://hf.co/models?search=awq.
-    /// Should replace GPTQ models whereever possible because of the better latency
+    /// Should replace GPTQ models wherever possible because of the better latency
     Awq,
     /// 8 bit quantization, doesn't require specific model.
     /// Should be a drop-in replacement to bitsandbytes with much better performance.
     /// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
     Eetq,
     /// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq.
-    /// text-generation-inference will use exllama (faster) kernels whereever possible, and use
+    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
     /// triton kernel (wider support) when it's not.
     /// AWQ has faster kernels.
     Gptq,
@ -368,6 +368,11 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
 
+    /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
+    /// include a `chat_template`. If not provided, the default config will be used from the model hub.
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+
     /// Display a lot of information about your runtime environment
     #[clap(long, short, action)]
     env: bool,
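The new --tokenizer-config-path flag lets the launcher hand the router a local tokenizer_config.json, whose chat_template (when present) drives chat formatting instead of the config fetched from the model hub. A small sketch of inspecting such a file is below; the path shown is an illustrative assumption, not a value from this PR.

# Sketch: check whether a local tokenizer config carries a chat_template.
# Assumption: the path below is an example; pass the same file via --tokenizer-config-path.
import json

with open("my-model/tokenizer_config.json") as f:  # hypothetical local path
    config = json.load(f)

chat_template = config.get("chat_template")  # may be missing; the hub default is used then
print(chat_template is not None, chat_template)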
@ -489,6 +494,9 @@ fn shard_manager(
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
+    // Disable progress bar
+    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
+
     // Enable hf transfer for insane download speeds
     let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
     envs.push((
@ -573,6 +581,13 @@ fn shard_manager(
     thread::spawn(move || {
         log_lines(shard_stdout_reader.lines());
     });
+    // We read stderr in another thread as it seems that lines() can block in some cases
+    let (err_sender, err_receiver) = mpsc::channel();
+    thread::spawn(move || {
+        for line in shard_stderr_reader.lines().flatten() {
+            err_sender.send(line).unwrap_or(());
+        }
+    });
 
     let mut ready = false;
     let start_time = Instant::now();
@ -580,13 +595,6 @@ fn shard_manager(
     loop {
         // Process exited
         if let Some(exit_status) = p.try_wait().unwrap() {
-            // We read stderr in another thread as it seems that lines() can block in some cases
-            let (err_sender, err_receiver) = mpsc::channel();
-            thread::spawn(move || {
-                for line in shard_stderr_reader.lines().flatten() {
-                    err_sender.send(line).unwrap_or(());
-                }
-            });
             let mut err = String::new();
             while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                 err = err + "\n" + &line;
@ -782,6 +790,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
+    // Disable progress bar
+    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
+
     // If huggingface_hub_cache is set, pass it to the download process
     // Useful when running inside a docker container
     if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
@ -832,12 +843,20 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
         }
     };
 
-    // Redirect STDOUT to the console
-    let download_stdout = download_process.stdout.take().unwrap();
-    let stdout = BufReader::new(download_stdout);
+    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
 
     thread::spawn(move || {
-        log_lines(stdout.lines());
+        log_lines(download_stdout.lines());
     });
 
+    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
+
+    // We read stderr in another thread as it seems that lines() can block in some cases
+    let (err_sender, err_receiver) = mpsc::channel();
+    thread::spawn(move || {
+        for line in download_stderr.lines().flatten() {
+            err_sender.send(line).unwrap_or(());
+        }
+    });
+
     loop {
@ -848,12 +867,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
             }
 
             let mut err = String::new();
-            download_process
-                .stderr
-                .take()
-                .unwrap()
-                .read_to_string(&mut err)
-                .unwrap();
+            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
+                err = err + "\n" + &line;
+            }
+
             if let Some(signal) = status.signal() {
                 tracing::error!(
                     "Download process was signaled to shutdown with signal {signal}: {err}"
@ -965,7 +982,20 @@ fn spawn_shards(
     Ok(())
 }
 
+fn compute_type(num_shard: usize) -> Option<String> {
+    let output = Command::new("nvidia-smi")
+        .args(["--query-gpu=gpu_name", "--format=csv"])
+        .output()
+        .ok()?;
+    let output = String::from_utf8(output.stdout).ok()?;
+    let fullname = output.split('\n').nth(1)?;
+    let cardname = fullname.replace(' ', "-").to_lowercase();
+    let compute_type = format!("{num_shard}-{cardname}");
+    Some(compute_type)
+}
+
 fn spawn_webserver(
+    num_shard: usize,
     args: Args,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
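The new compute_type helper above derives a "<num_shard>-<gpu-name>" label from nvidia-smi, and spawn_webserver exports it as COMPUTE_TYPE unless the variable is already set (see the later hunk). The Python sketch below simply mirrors that parsing so the value can be previewed on a given machine; it assumes nvidia-smi is on the PATH and is an illustration only, not code from this PR.

# Sketch mirroring compute_type(): "<num_shard>-<lowercased GPU name, spaces replaced by dashes>".
# Assumption: nvidia-smi is installed; without it the launcher simply does not set COMPUTE_TYPE.
import subprocess
from typing import Optional

def compute_type(num_shard: int) -> Optional[str]:
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return None
    lines = out.split("\n")
    if len(lines) < 2:
        return None
    cardname = lines[1].replace(" ", "-").lower()  # line 0 is the CSV header, line 1 the first GPU name
    return f"{num_shard}-{cardname}"

print(compute_type(1))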
@ -1004,6 +1034,12 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    // Tokenizer config path
+    if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
+        router_args.push("--tokenizer-config-path".to_string());
+        router_args.push(tokenizer_config_path.to_string());
+    }
+
     // Model optional max batch total tokens
     if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
         router_args.push("--max-batch-total-tokens".to_string());
@ -1049,6 +1085,13 @@ fn spawn_webserver(
         envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
+    // Parse Compute type
+    if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
+        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
+    } else if let Some(compute_type) = compute_type(num_shard) {
+        envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
+    }
+
     let mut webserver = match Command::new("text-generation-router")
         .args(router_args)
         .envs(envs)
@ -1242,8 +1285,8 @@ fn main() -> Result<(), LauncherError> {
         return Ok(());
     }
 
-    let mut webserver =
-        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
+    let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
+        .map_err(|err| {
             shutdown_shards(shutdown.clone(), &shutdown_receiver);
             err
         })?;
@ -32,7 +32,7 @@ reqwest = { version = "0.11.20", features = [] }
 serde = "1.0.188"
 serde_json = "1.0.107"
 thiserror = "1.0.48"
-tokenizers = { version = "0.14.0", features = ["http"] }
+tokenizers = { version = "0.15.1", features = ["http"] }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.14"
 tower-http = { version = "0.4.4", features = ["cors"] }
@ -165,6 +165,28 @@ impl Infer {
         ))
     }
 
+    /// Tokenizer the input
+    #[instrument(skip_all)]
+    pub(crate) async fn tokenize(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+        // Tokenize request
+        let inputs = request.inputs;
+        let truncate = request.parameters.truncate;
+        let encoding = self
+            .validation
+            .tokenize(inputs, truncate)
+            .await
+            .map_err(|err| {
+                tracing::error!("Tokenization {err}");
+                err
+            })?;
+
+        // Return Encoding
+        Ok(encoding.map(|(encoding, _)| encoding))
+    }
+
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
 }
 
 impl HubTokenizerConfig {
-    pub fn from_file(filename: &str) -> Self {
+    pub fn from_file(filename: &std::path::Path) -> Self {
         let content = std::fs::read_to_string(filename).unwrap();
         serde_json::from_str(&content).unwrap_or_default()
     }
@ -188,18 +188,20 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270835")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionComplete>,
     pub usage: Usage,
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
@ -248,17 +250,19 @@ impl ChatCompletion {
     }
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270978")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionChoice>,
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
@ -266,9 +270,11 @@ pub(crate) struct ChatCompletionChoice {
     pub finish_reason: Option<String>,
 }
 
-#[derive(Clone, Debug, Deserialize, Serialize)]
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
+    #[schema(example = "user")]
     pub role: String,
+    #[schema(example = "What is Deep Learning?")]
     pub content: String,
 }
 
@ -311,7 +317,7 @@ fn default_request_messages() -> Vec<Message> {
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct ChatRequest {
     /// UNUSED
-    #[schema(example = "bigscience/blomm-560m")]
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
     pub model: String, /* NOTE: UNUSED */
 
@ -322,6 +328,7 @@ pub(crate) struct ChatRequest {
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
     /// decreasing the model's likelihood to repeat the same line verbatim.
     #[serde(default)]
+    #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
     /// UNUSED
@ -336,28 +343,33 @@ pub(crate) struct ChatRequest {
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
     /// output token returned in the content of message.
     #[serde(default)]
+    #[schema(example = "false")]
     pub logprobs: Option<bool>,
 
     /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
+    #[schema(example = "5")]
     pub top_logprobs: Option<u32>,
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
+    #[schema(example = "32")]
     pub max_tokens: Option<u32>,
 
     /// UNUSED
     /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
     /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
     #[serde(default)]
+    #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,
 
     /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
+    #[schema(nullable = true, example = 0.1)]
     pub presence_penalty: Option<f32>,
 
     #[serde(default = "bool::default")]
@ -365,6 +377,20 @@ pub(crate) struct ChatRequest {
 
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic.
+    ///
+    /// We generally recommend altering this or `top_p` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
+    pub top_p: Option<f32>,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
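With temperature and top_p now part of ChatRequest, OpenAI-style chat requests can control sampling directly. A minimal sketch with the requests library follows; the base URL and model name are illustrative assumptions (the model field is forwarded but marked UNUSED by the router), so adapt them to your deployment.

# Sketch: call the OpenAI-compatible chat route with the newly exposed sampling parameters.
# Assumptions: a TGI router at http://127.0.0.1:8080 with the Messages API enabled; names are illustrative.
import requests

payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",  # accepted but currently unused by the router
    "messages": [{"role": "user", "content": "What is Deep Learning?"}],
    "max_tokens": 32,
    "temperature": 1.0,  # higher values are more random; tune this or top_p, not both
    "top_p": 0.95,       # nucleus sampling probability mass
}

resp = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])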
@ -432,8 +458,21 @@ pub struct Token {
     special: bool,
 }
 
+#[derive(Debug, Serialize, ToSchema)]
+pub struct SimpleToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(example = 0)]
+    start: usize,
+    #[schema(example = 2)]
+    stop: usize,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
+#[schema(example = "Length")]
 pub(crate) enum FinishReason {
     #[schema(rename = "length")]
     Length,
@ -494,6 +533,10 @@ pub(crate) struct GenerateResponse {
|
|||||||
pub details: Option<Details>,
|
pub details: Option<Details>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
#[serde(transparent)]
|
||||||
|
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct StreamDetails {
|
pub(crate) struct StreamDetails {
|
||||||
#[schema(example = "length")]
|
#[schema(example = "length")]
|
||||||
@ -524,26 +567,12 @@ pub(crate) struct ErrorResponse {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::io::Write;
|
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
||||||
let filename = std::path::Path::new("tokenizer.json");
|
let api = hf_hub::api::sync::Api::new().unwrap();
|
||||||
if !filename.exists() {
|
let repo = api.model("gpt2".to_string());
|
||||||
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
|
let filename = repo.get("tokenizer.json").unwrap();
|
||||||
.await
|
Tokenizer::from_file(filename).unwrap()
|
||||||
.unwrap()
|
|
||||||
.bytes()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let tmp_filename = "tokenizer.json.temp";
|
|
||||||
let mut file = std::fs::File::create(tmp_filename).unwrap();
|
|
||||||
file.write_all(&content).unwrap();
|
|
||||||
// Re-check if another process has written this file maybe.
|
|
||||||
if !filename.exists() {
|
|
||||||
std::fs::rename(tmp_filename, filename).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Tokenizer::from_file("tokenizer.json").unwrap()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -72,7 +72,7 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 }
 
 #[tokio::main]
@@ -104,7 +104,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        chat_enabled_api,
+        messages_api_enabled,
     } = args;
 
     // Launch Tokio runtime
@@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> {
     let local_path = Path::new(&tokenizer_name);
     let local_model = local_path.exists() && local_path.is_dir();
 
-    // Load tokenizer config
-    // This will be used to format the chat template
-    let local_tokenizer_config_path =
-        tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
-    let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
-
     // Shared API builder initialization
     let api_builder = || {
         let mut builder = ApiBuilder::new()
@@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> {
     };
 
     // Load tokenizer config if found locally, or check if we can get it from the API if needed
-    let tokenizer_config = if local_tokenizer_config {
+    let tokenizer_config = if let Some(path) = tokenizer_config_path {
+        tracing::info!("Using local tokenizer config from user specified path");
+        HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
+    } else if local_model {
         tracing::info!("Using local tokenizer config");
-        HubTokenizerConfig::from_file(&local_tokenizer_config_path)
-    } else if let Some(api) = api {
-        tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
-        get_tokenizer_config(&api.repo(Repo::with_revision(
-            tokenizer_name.to_string(),
-            RepoType::Model,
-            revision.unwrap_or_else(|| "main".to_string()),
-        )))
-        .await
-        .unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
-            HubTokenizerConfig::default()
-        })
+        HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
     } else {
-        tracing::warn!("Could not find tokenizer config locally and no revision specified");
-        HubTokenizerConfig::default()
+        match api {
+            Some(api) => {
+                tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
+                let repo = Repo::with_revision(
+                    tokenizer_name.to_string(),
+                    RepoType::Model,
+                    revision.unwrap_or("main".to_string()),
+                );
+                get_tokenizer_config(&api.repo(repo))
+                    .await
+                    .unwrap_or_else(|| {
+                        tracing::warn!(
+                            "Could not retrieve tokenizer config from the Hugging Face hub."
+                        );
+                        HubTokenizerConfig::default()
+                    })
+            }
+            None => {
+                tracing::warn!("Could not find tokenizer config locally and no API specified");
+                HubTokenizerConfig::default()
+            }
+        }
     };
 
     if tokenizer.is_none() {
@@ -348,7 +353,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
-        chat_enabled_api,
+        messages_api_enabled,
     )
     .await?;
     Ok(())
@@ -462,7 +467,12 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
     let reader = BufReader::new(file);
 
     // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?;
+    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
+        .map_err(|e| {
+            tracing::warn!("Unable to parse tokenizer config: {}", e);
+            e
+        })
+        .ok()?;
 
     Some(tokenizer_config)
 }
@@ -3,10 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
-    Token, Validation,
+    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
+    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
+    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
+    compute_type: Extension<ComputeType>,
     Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // default return_full_text given the pipeline_tag
@@ -66,11 +67,11 @@ async fn compat_generate(
 
     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, Json(req.into()))
+        Ok(generate_stream(infer, compute_type, Json(req.into()))
             .await
             .into_response())
     } else {
-        let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation])).into_response())
     }
@@ -145,6 +146,7 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
+    Extension(ComputeType(compute_type)): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
@@ -230,7 +232,7 @@ async fn generate(
 
     // Headers
     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-time",
         total_time.as_millis().to_string().parse().unwrap(),
@@ -339,6 +341,7 @@ seed,
 )]
 async fn generate_stream(
     Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
@@ -349,13 +352,14 @@ async fn generate_stream(
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, Json(req), on_message_callback).await;
+        generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
 
 async fn generate_stream_internal(
     infer: Infer,
+    ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     on_message_callback: impl Fn(StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
@@ -368,7 +372,7 @@ async fn generate_stream_internal(
     let compute_characters = req.inputs.chars().count();
 
     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-characters",
         compute_characters.to_string().parse().unwrap(),
@@ -532,7 +536,7 @@ async fn generate_stream_internal(
 path = "/v1/chat/completions",
 request_body = ChatRequest,
 responses(
-(status = 200, description = "Generated Text", body = GenerateResponse),
+(status = 200, description = "Generated Text", body = ChatCompletionChunk),
 (status = 424, description = "Generation Error", body = ErrorResponse,
 example = json ! ({"error": "Request failed during generation"})),
 (status = 429, description = "Model is overloaded", body = ErrorResponse,
@@ -557,6 +561,7 @@ async fn generate_stream_internal(
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -592,10 +597,10 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: None,
+            temperature: req.temperature,
             repetition_penalty,
             top_k: None,
-            top_p: None,
+            top_p: req.top_p,
             typical_p: None,
             do_sample: true,
             max_new_tokens,
@@ -604,7 +609,7 @@ async fn chat_completions(
             truncate: None,
             watermark: false,
             details: true,
-            decoder_input_details: true,
+            decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
         },
@@ -644,13 +649,22 @@ async fn chat_completions(
             )
         };
 
-        let (headers, response_stream) =
-            generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
+        let (headers, response_stream) = generate_stream_internal(
+            infer,
+            compute_type,
+            Json(generate_request),
+            on_message_callback,
+        )
+        .await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) =
-            generate(Extension(infer), Json(generate_request)).await?;
+        let (headers, Json(generation)) = generate(
+            Extension(infer),
+            Extension(compute_type),
+            Json(generate_request),
+        )
+        .await?;
 
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -672,6 +686,52 @@ async fn chat_completions(
     }
 }
 
+/// Tokenize inputs
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/tokenize",
+    request_body = GenerateRequest,
+    responses(
+    (status = 200, description = "Tokenized ids", body = TokenizeResponse),
+    (status = 404, description = "No tokenizer found", body = ErrorResponse,
+    example = json ! ({"error": "No fast tokenizer available"})),
+    )
+)]
+#[instrument(skip_all)]
+async fn tokenize(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<GenerateRequest>,
+) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
+    let input = req.inputs.clone();
+    let encoding = infer.tokenize(req).await?;
+    if let Some(encoding) = encoding {
+        let tokens: Vec<SimpleToken> = encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, &(start, stop))| {
+                let text: String = input.chars().skip(start).take(stop - start).collect();
+                SimpleToken {
+                    id,
+                    text,
+                    start,
+                    stop,
+                }
+            })
+            .collect();
+        Ok(Json(TokenizeResponse(tokens)))
+    } else {
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(ErrorResponse {
+                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
+                error_type: "no fast tokenizer".to_string(),
+            }),
+        ))
+    }
+}
+
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
 get,
@@ -683,6 +743,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
 
+#[derive(Clone, Debug)]
+pub(crate) struct ComputeType(String);
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
@@ -708,7 +771,7 @@ pub async fn run(
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -719,6 +782,8 @@ pub async fn run(
     compat_generate,
     generate,
     generate_stream,
+    chat_completions,
+    tokenize,
     metrics,
     ),
     components(
@@ -726,10 +791,18 @@ pub async fn run(
     Info,
     CompatGenerateRequest,
     GenerateRequest,
+    ChatRequest,
+    Message,
+    ChatCompletionChoice,
+    ChatCompletionDelta,
+    ChatCompletionChunk,
+    ChatCompletion,
     GenerateParameters,
     PrefillToken,
     Token,
     GenerateResponse,
+    TokenizeResponse,
+    SimpleToken,
     BestOfSequence,
     Details,
     FinishReason,
@@ -863,21 +936,26 @@ pub async fn run(
     // Define base and health routes
     let base_routes = Router::new()
         .route("/", post(compat_generate))
+        .route("/", get(health))
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/tokenize", post(tokenize))
        .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics));
 
     // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if chat_enabled_api {
+    let aws_sagemaker_route = if messages_api_enabled {
         Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
     } else {
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
     };
 
+    let compute_type =
+        ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
+
     // Combine routes and layers
     let app = Router::new()
         .merge(swagger_ui)
@@ -887,6 +965,7 @@ pub async fn run(
         .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
+        .layer(Extension(compute_type))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
         .layer(cors_layer);
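The `ComputeType` extension replaces the hard-coded `x-compute-type` response header: its value now comes from the `COMPUTE_TYPE` environment variable and falls back to `gpu+optimized` when unset. A small client-side sketch of how this surfaces (illustrative only; the server address and the example value are assumptions):

    import requests

    # The router reads COMPUTE_TYPE at startup, e.g. COMPUTE_TYPE="2-a100" before launching it.
    resp = requests.post(
        "http://localhost:8080/generate",
        json={"inputs": "Hello", "parameters": {"max_new_tokens": 4}},
    )
    print(resp.headers.get("x-compute-type"))  # "gpu+optimized" when the variable is unset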
@@ -70,12 +70,11 @@ impl Validation {
     }
 
     #[instrument(skip(self, inputs))]
-    async fn validate_input(
+    pub async fn tokenize(
         &self,
         inputs: String,
         truncate: Option<usize>,
-        max_new_tokens: Option<u32>,
-    ) -> Result<(String, usize, u32), ValidationError> {
+    ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel
@@ -88,7 +87,24 @@ impl Validation {
 
             // Await on response channel
             // Unwrap is safe here
-            let (inputs, input_length) = response_receiver.await.unwrap()?;
+            let encoding = response_receiver.await.unwrap()?;
+            Ok(Some(encoding))
+        } else {
+            Ok(None)
+        }
+    }
+
+    #[instrument(skip(self, inputs))]
+    async fn validate_input(
+        &self,
+        inputs: String,
+        truncate: Option<usize>,
+        max_new_tokens: Option<u32>,
+    ) -> Result<(String, usize, u32), ValidationError> {
+        // If we have a fast tokenizer
+        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+            // Create response channel
+            let input_length = encoding.len();
 
             // Get total tokens
             let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
@@ -343,36 +359,31 @@ fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<
 
 /// Get input length and optionally truncate it
 fn prepare_input(
-    inputs: String,
+    mut inputs: String,
     truncate: Option<usize>,
     tokenizer: &Tokenizer,
-) -> Result<(String, usize), ValidationError> {
+) -> Result<(tokenizers::Encoding, String), ValidationError> {
     // Get the number of tokens in the input
     let mut encoding = tokenizer
         .encode(inputs.clone(), true)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     // Optionally truncate
-    let (inputs, input_length) = match truncate {
-        // Truncate is some and < encoding length
-        Some(truncate) if truncate < encoding.len() => {
-            // truncate encoding and decode new inputs
+    if let Some(truncate) = truncate {
+        if truncate < encoding.len() {
             encoding.truncate(truncate, 0, TruncationDirection::Left);
-            let inputs = tokenizer
+            inputs = tokenizer
                 .decode(encoding.get_ids(), false)
                 .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-            (inputs, encoding.len())
         }
-        // Nothing to do
-        _ => (inputs, encoding.len()),
-    };
+    }
 
-    Ok((inputs, input_length))
+    Ok((encoding, inputs))
 }
 
 type TokenizerRequest = (
     (String, Option<usize>),
-    oneshot::Sender<Result<(String, usize), ValidationError>>,
+    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
     Span,
 );
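After this refactor, `prepare_input` returns the full `Encoding` (and the possibly re-decoded input text) rather than just a length, truncating from the left whenever `truncate` is smaller than the token count. A rough Python equivalent of that left-truncation, shown only to clarify the behaviour (the gpt2 tokenizer is an assumption for illustration):

    from typing import Optional
    from tokenizers import Tokenizer

    def prepare_input(inputs: str, truncate: Optional[int], tokenizer: Tokenizer):
        # Encode once, mirroring the Rust code above.
        encoding = tokenizer.encode(inputs)
        ids = encoding.ids
        if truncate is not None and truncate < len(ids):
            # Keep the *last* `truncate` tokens (TruncationDirection::Left) and re-decode the text.
            ids = ids[-truncate:]
            inputs = tokenizer.decode(ids, skip_special_tokens=False)
        return ids, inputs

    tok = Tokenizer.from_pretrained("gpt2")  # assumed tokenizer, for illustration only
    print(prepare_input("one two three four five", 2, tok))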
@@ -1,4 +1,4 @@
-eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
+eetq_commit := 71adb5e191bb8290069a580abff0355d7b2dd5c9
 
 eetq:
 	# Clone eetq
@@ -6,7 +6,7 @@ eetq:
 	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
 
 build-eetq: eetq
-	cd eetq && git fetch && git checkout $(eetq_commit)
+	cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
 	cd eetq && python setup.py build
 
 install-eetq: build-eetq
@@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 
 //
 
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ < 700
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
 
 __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 
-#if __CUDA_ARCH__ < 600
+#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
 __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
@@ -2,8 +2,11 @@
 #include "column_remap.cuh"
 #include "../util.cuh"
 #include "../matrix.cuh"
-#include "../cuda_compat.cuh"
+#include "../cu_compat.cuh"
 #include "../cuda_buffers.cuh"
+#if defined(USE_ROCM)
+#include "../hip_compat.cuh"
+#endif
 
 const int THREADS_X = 32;     // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;      // Block size and thread count along rows in x and out
@@ -128,7 +131,7 @@ __global__ void q4_matmul_kernel
 
     if constexpr (use_half2)
     {
-        half result = __hadd(acc.x, acc.y);
+        half result = __hadd(__low2half(acc), __high2half(acc));
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
@@ -1,12 +1,23 @@
-#ifndef _compat_gemm_cuh
-#define _compat_gemm_cuh
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
 
-#if defined(USE_ROCM)
+#ifndef _hip_compat_cuh
+#define _hip_compat_cuh
 
-// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
-// for symbols as hipblasHalf.
-#include <hipblas/hipblas.h>
+// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
+__device__ __forceinline__ __half __compat_hrcp(__half x) {
+    return __half_raw{
+        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
+}
+
+__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
+    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
+                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+}
+
+#define hrcp __compat_hrcp
+#define h2rcp __compat_h2rcp
 
+// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -31,8 +42,10 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t
 #define hipblasHgemm __compat_hipblasHgemm
 
 // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_handle hipblasHandle_t
 #define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_get_stream hipblasGetStream
+#define rocblas_set_stream hipblasSetStream
 #define rocblas_hgemm __compat_hipblasHgemm
-#endif
 
 #endif
@@ -8,7 +8,11 @@
 #include <cstdint>
 #include <cstdio>
 
+#if defined(USE_ROCM)
+#define cudaUnspecified hipErrorUnknown
+#else
 #define cudaUnspecified cudaErrorApiFailureBase
+#endif
 
 // React to failure on return code != cudaSuccess
@@ -1,5 +1,15 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+
+extra_cuda_cflags = ["-lineinfo", "-O3"]
+
+if torch.version.hip:
+    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+
+extra_compile_args = {
+    "nvcc": extra_cuda_cflags,
+}
 
 setup(
     name="exllamav2_kernels",
@@ -11,6 +21,7 @@ setup(
                 "exllamav2_kernels/cuda/q_matrix.cu",
                 "exllamav2_kernels/cuda/q_gemm.cu",
             ],
+            extra_compile_args=extra_compile_args,
         )
     ],
     cmdclass={"build_ext": BuildExtension},
1193 server/poetry.lock generated
File diff suppressed because it is too large Load Diff

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "1.3.4"
+version = "1.4.0"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "^0.15.0"
 huggingface-hub = "^0.19.3"
-transformers = "^4.36.1"
+transformers = "^4.37.1"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
@@ -13,11 +13,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -28,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
+regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -12,11 +12,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -27,18 +27,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
+regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
77 server/tests/utils/test_layers.py Normal file
@@ -0,0 +1,77 @@
+import torch
+from text_generation_server.utils.layers import (
+    TensorParallelEmbedding,
+)
+
+
+class ProcessGroup:
+    def __init__(self, rank: int, world_size: int):
+        self._rank = rank
+        self.world_size = world_size
+
+    def size(self) -> int:
+        return self.world_size
+
+    def rank(self) -> int:
+        return self._rank
+
+
+class Weights:
+    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
+        self.weight = (
+            torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
+        )
+        self.process_group = ProcessGroup(rank, world_size)
+
+    def get_partial_sharded(self, name: str, dim: int):
+        assert dim == 0
+
+        rank = self.process_group.rank()
+        world_size = self.process_group.size()
+        size = self.weight.shape[dim]
+
+        block_size = (size + world_size - 1) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        return self.weight[start:stop]
+
+    def get_shape(self, name: str):
+        return self.weight.shape
+
+
+def test_weight_hub_files_offline_error():
+
+    vocab_size = 17
+    weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
+    embeddings = TensorParallelEmbedding("", weights)
+
+    input_ids = torch.arange(vocab_size)
+    output = embeddings.forward(input_ids)
+    assert embeddings.min_id == 0
+    assert embeddings.max_id == 17
+    torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
+
+    weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
+    weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
+    embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
+    assert embeddings_0_2.min_id == 0
+    assert embeddings_0_2.max_id == 9
+    torch.testing.assert_close(
+        embeddings_0_2.weight,
+        torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
+        .view(10, 256)
+        .float(),
+    )
+    embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
+    assert embeddings_1_2.min_id == 9
+    assert embeddings_1_2.max_id == 17
+    torch.testing.assert_close(
+        embeddings_1_2.weight,
+        torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
+        .view(9, 256)
+        .float(),
+    )
+    output_tp_0 = embeddings_0_2.forward(input_ids)
+    output_tp_1 = embeddings_1_2.forward(input_ids)
+
+    torch.testing.assert_close(output, output_tp_0 + output_tp_1)
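The new test pins down how `TensorParallelEmbedding` splits the vocabulary: each rank owns a contiguous block of ceil(vocab_size / world_size) rows, ids outside the block map to a zero row, and summing the per-rank outputs reproduces the full embedding. A tiny standalone sketch of just the boundary arithmetic (illustrative, not taken from the diff):

    def shard_bounds(vocab_size: int, world_size: int, rank: int):
        # Same ceil-division used by Weights.get_partial_sharded in the test above.
        block_size = (vocab_size + world_size - 1) // world_size
        start = rank * block_size
        stop = min((rank + 1) * block_size, vocab_size)
        return start, stop

    # vocab_size=17, world_size=2 -> rank 0 owns [0, 9), rank 1 owns [9, 17),
    # matching the min_id/max_id values asserted in the test.
    print(shard_bounds(17, 2, 0), shard_bounds(17, 2, 1))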
@@ -50,19 +50,39 @@ def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
+    accepted_ids = torch.ones_like(top_n_tokens_tensor)
 
     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
-        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )
 
-    assert topn_tok_ids[0] == []
-    assert topn_tok_ids[1] == [0, 3]
-    assert topn_tok_ids[2] == [0, 3, 1, 4]
-    assert topn_tok_ids[3] == [0, 3, 1, 4]
-    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
 
-    assert topn_tok_logprobs[0] == []
-    assert topn_tok_logprobs[1] == [-1, -2]
-    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
+
+    # Now let's make second member of the batch be speculated
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
+    accepted_ids[1] = 2
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
+    )
+
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3], [0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
+
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
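The nested lists come from speculative decoding: `batch_top_tokens` is now handed `accepted_ids`, the number of speculated positions kept for each request, and returns one inner list per kept position. Roughly, the regrouping of flat per-position results back into per-request lists looks like the sketch below; this is plain illustrative Python for the result shape only, not the server's actual indexing:

    def regroup_rows(rows, accepted_ids):
        # Split a flat list of per-position rows into one sub-list per request.
        grouped, cursor = [], 0
        for n_accepted in accepted_ids:
            grouped.append(rows[cursor:cursor + n_accepted])
            cursor += n_accepted
        return grouped

    # Second request kept 2 speculated positions, the rest 1, as in the test above.
    print(regroup_rows(["r0", "r1a", "r1b", "r2", "r3", "r4"], [1, 2, 1, 1, 1]))
    # -> [['r0'], ['r1a', 'r1b'], ['r2'], ['r3'], ['r4']]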
@@ -19,6 +19,7 @@ from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.mamba import Mamba
+from text_generation_server.models.phi import Phi
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -58,6 +59,7 @@ try:
     from text_generation_server.models.idefics import IDEFICSSharded
     from text_generation_server.models.flash_mistral import FlashMistral
     from text_generation_server.models.flash_mixtral import FlashMixtral
+    from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 
 except ImportError as e:
@@ -73,6 +75,7 @@ if FLASH_ATTENTION:
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
+    __all__.append(FlashPhi)
 
 
 def get_model(
@@ -247,6 +250,39 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    elif model_type == "phi":
+        if FLASH_ATTENTION:
+            return FlashPhi(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                use_medusa=use_medusa,
+            )
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
+    elif model_type == "phi-msft":
+        if FLASH_ATTENTION:
+            raise NotImplementedError(
+                "Legacy phi-msft is not supported with Flash Attention"
+            )
+        else:
+            return Phi(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     elif model_type == "llama" or model_type == "baichuan":
         if FLASH_ATTENTION:
             return FlashLlama(
@@ -580,10 +580,13 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True
 
+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )
 
         start_decode = time.time_ns()
@@ -692,20 +695,24 @@ class CausalLM(Model):
                 prefill_tokens = None
 
             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None
 
@@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
+        self.rotary_dim = int(config.rotary_pct * self.head_size)
+
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@@ -98,8 +100,11 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         self.num_heads = self.num_heads // weights.process_group.size()
 
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.rotary_dim,
+            base=config.rotary_emb_base,
+            device=weights.device,
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
|
@ -0,0 +1,410 @@
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    TensorParallelHead,
    get_linear,
    FastLayerNorm,
)


class PhiConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=51200,
        hidden_size=2560,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="gelu_fast",  # llama uses silu
        layer_norm_eps=1e-05,  # rms in llama,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        resid_pdrop=0.1,  # llama doesn't have this
        partial_rotary_factor=0.5,  # important difference between llama and phi
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.rope_theta = rope_theta
        self.resid_pdrop = resid_pdrop
        self.partial_rotary_factor = partial_rotary_factor

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    # this is the same as llama except for Phi uses bias=True
    return TensorParallelColumnLinear(
        get_linear(weight, bias=True, quantize=config.quantize)
    )


class FlashPhiAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.softmax_scale = self.head_size**-0.5
        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.rotary_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )

        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        # in llama the dense layer is called "o_proj" and has bias=False
        self.dense = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.dense",
            weights=weights,
            bias=True,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        # Compute query, key, value and split
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )

        # Reshape query and key for rotary embeddings
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # NOTE: this is the main difference between Llama and Phi
        # in llama the rotary embeddings are applied to the whole query and key.
        # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
        #
        # Apply partial positional embeddings in place
        self.rotary_emb(
            query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
        )

        # Reshape key and value and cache
        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            flash_attn.attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )

        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


class PhiMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate="tanh"
                if act in ["gelu_fast", "gelu_pytorch_tanh"]
                else "none",
            )
        )

        # llama weights are up_proj and down_proj and bias=False
        self.up_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.fc1",
            weights=weights,
            bias=True,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.fc2",
            weights=weights,
            bias=True,
        )

    def forward(self, hidden_states):
        # NOTE: Llama requires the gate up states to an intermediate size
        # Phi does not and we can avoid the `view` operation
        return self.down_proj(self.act(self.up_proj(hidden_states)))


class FlashPhiLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = FlashPhiAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = FastLayerNorm.load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )
        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        hidden_states, res = self.input_layernorm(hidden_states, residual)
        # Self Attention
        attn_output = self.self_attn(
            hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )

        hidden_states = self.resid_dropout(attn_output).add(
            self.resid_dropout(self.mlp(hidden_states))
        )

        return hidden_states, res


class FlashPhiModel(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                FlashPhiLayer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

        self.norm = FastLayerNorm.load(
            prefix="model.final_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashPhiForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = FlashPhiModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config,
            prefix="lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        return self.lm_head(hidden_states)
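Note (not part of the diff): the comment in FlashPhiAttention.forward is the core idea of Phi's attention, so here is a minimal, self-contained sketch of what a partial rotary embedding does, using plain torch and toy shapes (head_size 64, rotary_dim 32). The function name apply_partial_rotary and the toy frequency setup are illustrative assumptions, not TGI's PositionRotaryEmbedding kernel.

import torch

def apply_partial_rotary(q, cos, sin, rotary_dim):
    # Rotate only the first `rotary_dim` channels; pass the rest through unchanged.
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    q1, q2 = q_rot.chunk(2, dim=-1)
    rotated = torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)
    return torch.cat([rotated, q_pass], dim=-1)

# Toy shapes: 4 tokens, head_size 64, partial_rotary_factor 0.5 -> rotary_dim 32.
head_size, rotary_dim = 64, 32
q = torch.randn(4, head_size)
inv_freq = 1.0 / 10000 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
angles = torch.outer(torch.arange(4, dtype=torch.float), inv_freq)
cos, sin = angles.cos(), angles.sin()
out = apply_partial_rotary(q, cos, sin, rotary_dim)
assert torch.equal(out[..., rotary_dim:], q[..., rotary_dim:])  # tail channels untouched

The production code does the same thing in place by only passing the `[..., : self.rotary_dim]` slices of query and key to the rotary kernel.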
@ -28,7 +28,6 @@ EPS = 1e-5


 def load_col(config, prefix, weights, bias):
-    assert bias == False, NotImplementedError
     assert config.quantize != "gptq", NotImplementedError
     slice_ = weights._get_slice(f"{prefix}.weight")
     rank = weights.process_group.rank()
@ -45,7 +44,36 @@ def load_col(config, prefix, weights, bias):
     if weight.dtype != torch.int32:
         weight = weight.to(dtype=weights.dtype)
     weight = weight.to(device=weights.device)
-    bias = None
+
+    if bias:
+        bias_slice_ = weights._get_slice(f"{prefix}.bias")
+        bias_rank = weights.process_group.rank()
+        bias_size = weights.process_group.size()
+
+        bias_h = bias_slice_.get_shape()
+        bias_h = bias_h[0]
+        bias_block_size = bias_h // bias_size
+
+        bias_q_part = bias_slice_[
+            bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
+        ]
+        bias_k_part = bias_slice_[
+            bias_h
+            + bias_rank * bias_block_size : bias_h
+            + (bias_rank + 1) * bias_block_size
+        ]
+        bias_v_part = bias_slice_[
+            2 * bias_h
+            + bias_rank * bias_block_size : 2 * bias_h
+            + (bias_rank + 1) * bias_block_size
+        ]
+
+        bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
+        if bias.dtype != torch.int32:
+            bias = bias.to(dtype=weights.dtype)
+        bias = bias.to(device=weights.device)
+    else:
+        bias = None
     linear = get_linear(weight, bias, config.quantize)
     return TensorParallelColumnLinear(linear)
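Note (illustrative sketch, not the TGI loader): the new load_col branch shards the fused Wqkv bias, which is laid out as [q_bias | k_bias | v_bias] with each segment of length h. Each rank takes its slice from every segment and concatenates them, mirroring the sharded weight rows. Toy sizes below are assumptions for the example.

import torch

h, world_size, rank = 8, 2, 1
bias = torch.arange(3 * h, dtype=torch.float32)  # stands in for the full Wqkv bias
block = h // world_size

q_part = bias[rank * block : (rank + 1) * block]
k_part = bias[h + rank * block : h + (rank + 1) * block]
v_part = bias[2 * h + rank * block : 2 * h + (rank + 1) * block]
sharded = torch.cat([q_part, k_part, v_part], dim=0)
print(sharded.tolist())  # rank 1 gets rows 4..7 of q, 12..15 of k and 20..23 of v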
@ -330,7 +358,16 @@ class MultiheadAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         if self.qk_ln:
-            raise NotImplementedError("qk_ln is not supported")
+            bias = not config.no_bias
+            hidden_size = config.d_model
+            head_dim = hidden_size // self.n_heads
+
+            self.q_ln = LPLayerNorm(
+                d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
+            )
+            self.k_ln = LPLayerNorm(
+                self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
+            )
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@ -581,12 +618,20 @@ class MPTBlock(nn.Module):
                 f"""Not implemented attn {config.attn_config["attn_type"]}"""
             )
         resid_pdrop = config.resid_pdrop
-        self.norm_1 = nn.LayerNorm.load_no_bias(
-            prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
-        )
-        self.norm_2 = nn.LayerNorm.load_no_bias(
-            prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
-        )
+        if config.no_bias:
+            self.norm_1 = nn.LayerNorm.load_no_bias(
+                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
+            )
+            self.norm_2 = nn.LayerNorm.load_no_bias(
+                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
+            )
+        else:
+            self.norm_1 = nn.LayerNorm.load(
+                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
+            )
+            self.norm_2 = nn.LayerNorm.load(
+                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
+            )
         self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
         self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@ -635,6 +680,9 @@ class LPLayerNorm(torch.nn.LayerNorm):
         elementwise_affine=True,
         device=None,
         dtype=None,
+        bias: Optional[bool] = True,
+        prefix=None,
+        weights=None,
     ):
         super().__init__(
             normalized_shape=normalized_shape,
@ -642,7 +690,13 @@ class LPLayerNorm(torch.nn.LayerNorm):
             elementwise_affine=elementwise_affine,
             device=device,
             dtype=dtype,
+            bias=bias,
         )
+        if weights is not None:
+            self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
+            if bias:
+                self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
+            self.normalized_shape = self.weight.shape

     def forward(self, x):
         module_device = x.device
@ -755,20 +809,23 @@ class MPTModel(MPTPreTrainedModel):
             )

         self.wte = TensorParallelEmbedding("transformer.wte", weights)

         if not self.alibi:
-            # self.wpe = torch.nn.Embedding(
-            #     config.max_seq_len, config.d_model, device=config.init_device
-            # )
-            raise RuntimeError("no alibi no supported")
+            self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
         self.blocks = nn.ModuleList(
             [
                 MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
                 for i in range(config.n_layers)
             ]
         )
-        self.norm_f = nn.LayerNorm.load_no_bias(
-            prefix="transformer.norm_f", weights=weights, eps=EPS
-        )
+        if config.no_bias:
+            self.norm_f = nn.LayerNorm.load_no_bias(
+                prefix="transformer.norm_f", weights=weights, eps=EPS
+            )
+        else:
+            self.norm_f = nn.LayerNorm.load(
+                prefix="transformer.norm_f", weights=weights, eps=EPS
+            )
         self.is_causal = not self.prefix_lm
         self._attn_bias_initialized = False
         self.attn_bias = None
@ -787,8 +844,9 @@ class MPTModel(MPTPreTrainedModel):
             if config.verbose:
                 warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                 module.register_parameter("bias", None)
-        if config.verbose and config.verbose > 2:
-            print(self)
+        if hasattr(self.config, "verbose"):
+            if config.verbose and config.verbose > 2:
+                print(self)
         if "verbose" not in self.config.init_config:
             self.config.init_config["verbose"] = self.config.verbose
         if self.config.init_config["verbose"] > 1:
@ -0,0 +1,330 @@
# imlementation of the PhiModel and PhiForCausalLM classes

import torch
import torch.distributed

import math
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelHead,
    FastLinear,
)


# PhiConfig is the configuration class for the PhiModel.
class PhiConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=51200,
        n_positions=2048,
        n_embd=2560,
        n_layer=32,
        n_inner=None,
        n_head=32,
        rotary_dim=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_vocab_size_multiple=64,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        no_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
        self.rotary_dim = rotary_dim

        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_word_embeddings = tie_word_embeddings
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.no_bias = no_bias

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
        inv_freq_len = len(inv_freq)
        inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
        t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
        freqs = t.matmul(inv_freq)
        self.sin = freqs.sin()
        self.cos = freqs.cos()

    def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
        b_size, seqlen, three, _, _headdim = qkv.shape
        if three != 3:
            raise Exception("unexpected shape for qkv")
        _, rotary_dim = self.cos.shape
        rotary_dim = rotary_dim * 2
        q_rot = qkv[:, :, 0, :, :rotary_dim]
        q_pass = qkv[:, :, 0, :, rotary_dim:]
        k_rot = qkv[:, :, 1, :, :rotary_dim]
        k_pass = qkv[:, :, 1, :, rotary_dim:]
        q12 = torch.chunk(q_rot, 2, dim=-1)
        k12 = torch.chunk(k_rot, 2, dim=-1)
        q1, q2 = q12[0], q12[1]
        k1, k2 = k12[0], k12[1]
        c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
        s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
        q_rot = torch.cat(
            [
                q1 * c - q2 * s,
                q1 * s + q2 * c,
            ],
            dim=-1,
        )
        k_rot = torch.cat(
            [
                k1 * c - k2 * s,
                k1 * s + k2 * c,
            ],
            dim=-1,
        )
        q = torch.cat([q_rot, q_pass], dim=-1)
        k = torch.cat([k_rot, k_pass], dim=-1)
        v = qkv[:, :, 2]
        return q, k, v


# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
class PhiCausalLMHead(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.ln = nn.LayerNorm.load(
            prefix="lm_head.ln",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.linear = TensorParallelHead.load(
            config=config, prefix="lm_head.linear", weights=weights
        )

    def forward(self, hidden_states):
        hidden_states = self.ln(hidden_states)
        hidden_states = self.linear(hidden_states)
        return hidden_states


# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        self.op_size = config.n_embd
        self.head_dim = int(config.n_embd / config.n_head)
        self.num_heads = config.n_head
        self.rotary_emb = RotaryEmbedding(
            config.rotary_dim,
            config.n_positions,
        )
        self.softmax_scale = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self,
        hidden_states,
        past_kv_cache,
        attention_mask=None,
    ):
        b_size, seq_len, _n_embd = hidden_states.shape
        qkv = self.Wqkv(hidden_states)
        qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
        seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
        q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)

        # if there is a kv_cache, then we need to concatenate
        if past_kv_cache is not None:
            prev_k, prev_v = past_kv_cache
            k = torch.cat([prev_k, k], dim=1)
            v = torch.cat([prev_v, v], dim=1)

        past_kv_cache = [k, v]
        attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)

        if attention_mask is not None:
            seqlen_k = k.shape[1]
            seqlen_q = q.shape[1]
            causal_mask = torch.triu(
                torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
                1,
            )
            attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
        attn_output = (
            attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
            .transpose(1, 2)
            .flatten(-2)
        )
        return self.out_proj(attn_output), past_kv_cache


# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()

        self.n_inner = config.n_inner
        self.fc1 = FastLinear.load(
            config=config,
            prefix=f"{prefix}.fc1",
            weights=weights,
            bias=False,
        )
        self.fc2 = FastLinear.load(
            config=config,
            prefix=f"{prefix}.fc2",
            weights=weights,
            bias=False,
        )
        self.activation = torch.nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
class PhiBlock(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        self.layer_id = layer_id
        self.layer_norm = nn.LayerNorm.load(
            prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
        )
        self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
        self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)

    def forward(
        self,
        hidden_states,
        kv_cache,
        attention_mask,
    ):
        residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        attn_outputs, past_kv_cache = self.mixer(
            hidden_states, kv_cache, attention_mask
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        out = attn_outputs + feed_forward_hidden_states + residual
        return out, past_kv_cache


# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.tp_rank = weights.process_group.rank()
        self.tp_world_size = weights.process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="transformer.embd.wte", weights=weights
        )
        self.blocks = nn.ModuleList(
            [
                PhiBlock(f"transformer.h.{layer_id}", config, weights)
                for layer_id in range(config.n_layer)
            ]
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        hidden_states = self.embed_tokens(input_ids)
        seq_len = hidden_states.shape[1]
        mask = None if seq_len <= 1 else attention_mask

        past_key_values = (
            [None] * len(self.blocks) if past_key_values is None else past_key_values
        )

        for index, block in enumerate(self.blocks):
            hidden_states, new_key_values = block(
                hidden_states, past_key_values[index], mask
            )
            past_key_values[index] = new_key_values

        return hidden_states, past_key_values


# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.model = PhiModel(config, weights)
        self.lm_head = PhiCausalLMHead(config, weights)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        model_output = self.model(
            input_ids, past_key_values, attention_mask, return_dict, use_cache
        )
        logits = self.lm_head(model_output[0])

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
            )

        if not return_dict:
            return (
                ((loss,) + (logits,) + model_output[1:])
                if loss is not None
                else (logits,) + model_output[1:]
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output[1],
            hidden_states=None,
            attentions=None,
        )
@ -842,6 +842,8 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

+
+        speculate = get_speculate()
         (
             next_input_ids,
             next_token_logprobs,
@ -851,16 +853,15 @@ class FlashCausalLM(Model):
         ) = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen],
             next_token_logits,
-            get_speculate(),
+            speculate,
             batch.speculative_ids,
             speculative_logits,
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )

-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -1062,20 +1063,24 @@ class FlashCausalLM(Model):
             prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None
@ -74,9 +74,9 @@ class FlashLlama(FlashCausalLM):
             import os
             from pathlib import Path

-            is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
-                "WEIGHTS_CACHE_OVERRIDE", None
-            ) is not None
+            is_local_model = (
+                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

             if not is_local_model:
                 medusa_config = hf_hub_download(
102  server/text_generation_server/models/flash_phi.py  Normal file
@ -0,0 +1,102 @@
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
    FlashPhiForCausalLM,
    PhiConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashPhi(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        use_medusa: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashPhi is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = PhiConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashPhiForCausalLM(config, weights)
        if use_medusa:
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path

            is_local_model = (
                Path(use_medusa).exists() and Path(use_medusa).is_dir()
            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

            if not is_local_model:
                medusa_config = hf_hub_download(
                    use_medusa, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
                )
            else:
                medusa_config = str(Path(use_medusa) / "config.json")
                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

            with open(medusa_config, "r") as f:
                config = json.load(f)
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

        torch.distributed.barrier(group=self.process_group)
        super(FlashPhi, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
66  server/text_generation_server/models/phi.py  Normal file
@ -0,0 +1,66 @@
import torch
import torch.distributed

from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple

from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
    PhiConfig,
    PhiForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class Phi(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, _rank, _world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        config = PhiConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        tokenizer.bos_token_id = config.bos_token_id
        tokenizer.eos_token_id = config.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

        config.quantize = quantize
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        model = PhiForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )

+        # Speculation is not active for seq2seq
+        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )

         start_decode = time.time_ns()
@ -746,20 +749,24 @@ class Seq2SeqLM(Model):
             prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None
@ -95,5 +95,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
         )
@ -1,12 +1,9 @@
 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2

-from logging import getLogger
-
 import torch
 import torch.nn as nn
-import math

-logger = getLogger(__name__)
+from loguru import logger

 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half
@ -185,6 +182,10 @@ class QuantLinear(nn.Module):
             "g_idx": self.g_idx,
         }
         temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
+
+        # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
+        # and `Memory access fault by GPU node-2` will EAT you.
+        self.temp_dq = temp_dq
         self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)

     def forward(self, x, force_cuda=False):
@ -33,14 +33,14 @@ except Exception:
     major = 1

 HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8
+CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-    V2 = False
-    log_once(
-        logger.warning,
-        "Disabling exllama v2 and using v1 instead because there are issues when sharding",
-    )
+# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+#     V2 = False
+#     log_once(
+#         logger.warning,
+#         "Disabling exllama v2 and using v1 instead because there are issues when sharding",
+#     )

 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
@ -507,10 +507,12 @@ class TensorParallelEmbedding(nn.Module):
         world_size = process_group.size()
         rank = process_group.rank()

-        block_size = num_embeddings // world_size
+        block_size = (num_embeddings + world_size - 1) // world_size
         self.min_id = rank * block_size
         self.max_id = min(num_embeddings, (rank + 1) * block_size)
-        self.null_idx = block_size
+        self.null_idx = weight.shape[
+            0
+        ]  # Usually block_size, might be less in non even vocab_size.
         self.process_group = weights.process_group
         self.reduce = reduce
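Note (toy example, not the TGI class): the switch to ceil division is what lets TensorParallelEmbedding shard a vocabulary that does not divide evenly by the world size; the last rank simply gets a smaller block instead of some ids being dropped. The sizes below are made up for illustration.

# Sketch: num_embeddings=10 split across world_size=3 shards.
num_embeddings, world_size = 10, 3
block_size = (num_embeddings + world_size - 1) // world_size  # 4, not 10 // 3 == 3
shards = []
for rank in range(world_size):
    min_id = rank * block_size
    max_id = min(num_embeddings, (rank + 1) * block_size)
    shards.append((min_id, max_id))
print(shards)  # [(0, 4), (4, 8), (8, 10)] -> every id is owned by exactly one shard

With floor division the last shard would end at 9 and embedding id 9 would belong to no rank, which is why null_idx now comes from the actual sharded weight shape rather than block_size.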
@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
                 scores[:, j] = _scores
                 next_ids[:, j] = _next_ids
             next_ids = next_ids.view(B * S)
-            scores = scores.view(B * S, -1)
+            allscores = scores.view(B * S, -1)
+            alllogprobs = torch.log_softmax(allscores, -1)

         if speculated_ids is not None:
             accepted_ids = []
@ -305,16 +306,17 @@ class HeterogeneousNextTokenChooser:
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
-            scores = scores[indices]
+            logprobs = alllogprobs[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+            logprobs = alllogprobs

-        logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

+
         if speculate > 0:
             if speculative_scores is not None:
                 # Medusa provided some scores
@ -327,7 +329,7 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None

-        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
+        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
@ -436,8 +438,8 @@ class HeterogeneousSampling:


 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
-) -> Tuple[List[List[int]], List[List[float]]]:
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.

     When multiple tokens have equal probabilities and they don't all fit, the
@ -446,14 +448,19 @@ def batch_top_tokens(
     max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
     if max_top_n == 0:
-        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)

+    batch_size = accepted_ids.shape[0]
+    speculate_size = logprobs.shape[0] // batch_size
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]

     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values

     nth_highest = torch.gather(
         sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
@ -471,13 +478,33 @@ def batch_top_tokens(
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()

-    return (
-        [
-            idxs[:n] if req_n > 0 else []
-            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
-        ],
-        [
-            vals[:n] if req_n > 0 else []
-            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
-        ],
-    )
+    batch_top_token_ids = []
+    batch_top_token_logprobs = []
+    accepted_ids_list = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids_list):
+        start = speculate_size * i
+        stop = speculate_size * (i + 1)
+        _top_indices = top_indices[start: stop]
+        _top_values = top_values[start: stop]
+        _top_n_ishes = top_n_ishes[start: stop]
+        _top_n_tokens = top_n_tokens[start: stop]
+
+        _top_indices = _top_indices[:n_accepted_ids]
+        _top_values = _top_values[:n_accepted_ids]
+        _top_n_ishes = _top_n_ishes[:n_accepted_ids]
+        _top_n_tokens = _top_n_tokens[:n_accepted_ids]
+
+        row_top_token_ids = []
+        row_top_token_logprobs = []
+
+        for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
+            indices = idxs[:n] if req_n > 0 else []
+            values = vals[:n] if req_n > 0 else []
+
+            row_top_token_ids.append(indices)
+            row_top_token_logprobs.append(values)
+
+        batch_top_token_ids.append(row_top_token_ids)
+        batch_top_token_logprobs.append(row_top_token_logprobs)
+
+    return batch_top_token_ids, batch_top_token_logprobs
@ -92,7 +92,7 @@ class Weights:
         rank = self.process_group.rank()

         size = slice_.get_shape()[dim]
-        block_size = size // world_size
+        block_size = (size + world_size - 1) // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size