Merge branch 'main' into impl-simple-mamba-model

This commit is contained in:
drbh 2024-02-06 18:38:20 -05:00 committed by GitHub
commit 9146ba00a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 3876 additions and 1197 deletions

View File

@ -1,12 +0,0 @@
name: Delete doc comment
on:
pull_request:
types: [ closed ]
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
with:
pr_number: ${{ github.event.number }}

10
.gitignore vendored
View File

@ -2,3 +2,13 @@
target target
router/tokenizer.json router/tokenizer.json
*__pycache__* *__pycache__*
# ROCm auto-generated files
*.hip
server/exllamav2_kernels/exllamav2_kernels/hip/
server/exllama_kernels/exllama_kernels/hip/
server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

420
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,7 @@ members = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "1.3.4" version = "1.4.0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -75,8 +75,8 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \ mamba init && \
rm ~/mambaforge.sh rm ~/mambaforge.sh
# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6. # Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
FROM base AS kernel-builder FROM base AS kernel-builder
@ -104,6 +104,20 @@ WORKDIR /usr/src
COPY server/custom_kernels/ . COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
# Build exllama kernels
FROM kernel-builder as exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
# Build exllama v2 kernels
FROM kernel-builder as exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
FROM base as base-copy FROM base as base-copy
# Text Generation Inference base env # Text Generation Inference base env
@ -120,6 +134,12 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Install flash-attention dependencies # Install flash-attention dependencies
RUN pip install einops --no-cache-dir RUN pip install einops --no-cache-dir

View File

@ -28,7 +28,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
- [Local Install](#local-install) - [Local Install](#local-install)
- [CUDA Kernels](#cuda-kernels) - [CUDA Kernels](#cuda-kernels)
- [Optimized architectures](#optimized-architectures) - [Optimized architectures](#optimized-architectures)
- [Run Falcon](#run-falcon) - [Run Mistral](#run-a-model)
- [Run](#run) - [Run](#run)
- [Quantization](#quantization) - [Quantization](#quantization)
- [Develop](#develop) - [Develop](#develop)
@ -42,7 +42,11 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
- Token streaming using Server-Sent Events (SSE) - Token streaming using Server-Sent Events (SSE)
- Continuous batching of incoming requests for increased total throughput - Continuous batching of incoming requests for increased total throughput
- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) - Quantization with :
- [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [GPT-Q](https://arxiv.org/abs/2210.17323)
- [EETQ](https://github.com/NetEase-FuXi/EETQ)
- [AWQ](https://github.com/casper-hansen/AutoAWQ)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading - [Safetensors](https://github.com/huggingface/safetensors) weight loading
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@ -51,6 +55,14 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output - Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance - Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance
### Hardware support
- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference)
- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm)
- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)
- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)
- [Gaudi](https://github.com/huggingface/tgi-gaudi)
## Get Started ## Get Started
@ -62,7 +74,7 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
model=HuggingFaceH4/zephyr-7b-beta model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
``` ```
And then you can make requests like And then you can make requests like
@ -76,7 +88,7 @@ curl 127.0.0.1:8080/generate \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model` instead of the command above. **Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
``` ```
@ -106,7 +118,7 @@ model=meta-llama/Llama-2-7b-chat-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token> token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
``` ```
### A note on Shared Memory (shm) ### A note on Shared Memory (shm)
@ -154,7 +166,7 @@ Python 3.9, e.g. using `conda`:
```shell ```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
conda create -n text-generation-inference python=3.9 conda create -n text-generation-inference python=3.11
conda activate text-generation-inference conda activate text-generation-inference
``` ```
@ -180,7 +192,7 @@ Then run:
```shell ```shell
BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
make run-falcon-7b-instruct text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
``` ```
**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run: **Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
@ -189,16 +201,9 @@ make run-falcon-7b-instruct
sudo apt-get install libssl-dev gcc -y sudo apt-get install libssl-dev gcc -y
``` ```
### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
Be aware that the official Docker image has them enabled by default.
## Optimized architectures ## Optimized architectures
TGI works out of the box to serve optimized models in [this list](https://huggingface.co/docs/text-generation-inference/supported_models). TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
Other architectures are supported on a best-effort basis using: Other architectures are supported on a best-effort basis using:
@ -210,12 +215,12 @@ or
## Run Falcon ## Run locally
### Run ### Run
```shell ```shell
make run-falcon-7b-instruct text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
``` ```
### Quantization ### Quantization
@ -223,7 +228,7 @@ make run-falcon-7b-instruct
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement: You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
```shell ```shell
make run-falcon-7b-instruct-quantize text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
``` ```
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`. 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "1.3.4" "version": "1.4.0"
}, },
"paths": { "paths": {
"/": { "/": {
@ -342,6 +342,135 @@
} }
} }
} }
},
"/tokenize": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Tokenize inputs",
"description": "Tokenize inputs",
"operationId": "tokenize",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Tokenized ids",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/TokenizeResponse"
}
}
}
},
"404": {
"description": "No tokenizer found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "No fast tokenizer available"
}
}
}
}
}
}
},
"/v1/chat/completions": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "chat_completions",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ChatRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Text",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ChatCompletionChunk"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
}
}
}
}
}
}
} }
}, },
"components": { "components": {
@ -399,6 +528,226 @@
} }
} }
}, },
"ChatCompletion": {
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
"choices",
"usage"
],
"properties": {
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ChatCompletionComplete"
}
},
"created": {
"type": "integer",
"format": "int64",
"example": "1706270835",
"minimum": 0
},
"id": {
"type": "string"
},
"model": {
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
},
"usage": {
"$ref": "#/components/schemas/Usage"
}
}
},
"ChatCompletionChoice": {
"type": "object",
"required": [
"index",
"delta"
],
"properties": {
"delta": {
"$ref": "#/components/schemas/ChatCompletionDelta"
},
"finish_reason": {
"type": "string",
"nullable": true
},
"index": {
"type": "integer",
"format": "int32",
"minimum": 0
},
"logprobs": {
"type": "number",
"format": "float",
"nullable": true
}
}
},
"ChatCompletionChunk": {
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
"choices"
],
"properties": {
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ChatCompletionChoice"
}
},
"created": {
"type": "integer",
"format": "int64",
"example": "1706270978",
"minimum": 0
},
"id": {
"type": "string"
},
"model": {
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
}
}
},
"ChatCompletionDelta": {
"type": "object",
"required": [
"role",
"content"
],
"properties": {
"content": {
"type": "string",
"example": "What is Deep Learning?"
},
"role": {
"type": "string",
"example": "user"
}
}
},
"ChatRequest": {
"type": "object",
"required": [
"model"
],
"properties": {
"frequency_penalty": {
"type": "number",
"format": "float",
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
"example": "1.0",
"nullable": true
},
"logit_bias": {
"type": "array",
"items": {
"type": "number",
"format": "float"
},
"description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
"nullable": true
},
"logprobs": {
"type": "boolean",
"description": "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\noutput token returned in the content of message.",
"example": "false",
"nullable": true
},
"max_tokens": {
"type": "integer",
"format": "int32",
"description": "The maximum number of tokens that can be generated in the chat completion.",
"example": "32",
"nullable": true,
"minimum": 0
},
"messages": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
},
"description": "A list of messages comprising the conversation so far."
},
"model": {
"type": "string",
"description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"n": {
"type": "integer",
"format": "int32",
"description": "UNUSED\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.",
"example": "2",
"nullable": true,
"minimum": 0
},
"presence_penalty": {
"type": "number",
"format": "float",
"description": "UNUSED\nNumber between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics",
"example": 0.1,
"nullable": true
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true,
"minimum": 0
},
"stream": {
"type": "boolean"
},
"temperature": {
"type": "number",
"format": "float",
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.",
"example": 1.0,
"nullable": true
},
"top_logprobs": {
"type": "integer",
"format": "int32",
"description": "UNUSED\nAn integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.",
"example": "5",
"nullable": true,
"minimum": 0
},
"top_p": {
"type": "number",
"format": "float",
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95,
"nullable": true
}
}
},
"CompatGenerateRequest": { "CompatGenerateRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@ -494,7 +843,8 @@
"length", "length",
"eos_token", "eos_token",
"stop_sequence" "stop_sequence"
] ],
"example": "Length"
}, },
"GenerateParameters": { "GenerateParameters": {
"type": "object", "type": "object",
@ -523,7 +873,7 @@
"max_new_tokens": { "max_new_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"default": "20", "default": "100",
"example": "20", "example": "20",
"nullable": true, "nullable": true,
"minimum": 0 "minimum": 0
@ -758,6 +1108,23 @@
} }
} }
}, },
"Message": {
"type": "object",
"required": [
"role",
"content"
],
"properties": {
"content": {
"type": "string",
"example": "My name is David and I"
},
"role": {
"type": "string",
"example": "user"
}
}
},
"PrefillToken": { "PrefillToken": {
"type": "object", "type": "object",
"required": [ "required": [
@ -784,6 +1151,37 @@
} }
} }
}, },
"SimpleToken": {
"type": "object",
"required": [
"id",
"text",
"start",
"stop"
],
"properties": {
"id": {
"type": "integer",
"format": "int32",
"example": 0,
"minimum": 0
},
"start": {
"type": "integer",
"example": 0,
"minimum": 0
},
"stop": {
"type": "integer",
"example": 2,
"minimum": 0
},
"text": {
"type": "string",
"example": "test"
}
}
},
"StreamDetails": { "StreamDetails": {
"type": "object", "type": "object",
"required": [ "required": [
@ -812,6 +1210,7 @@
"StreamResponse": { "StreamResponse": {
"type": "object", "type": "object",
"required": [ "required": [
"index",
"token" "token"
], ],
"properties": { "properties": {
@ -830,6 +1229,11 @@
"example": "test", "example": "test",
"nullable": true "nullable": true
}, },
"index": {
"type": "integer",
"format": "int32",
"minimum": 0
},
"token": { "token": {
"$ref": "#/components/schemas/Token" "$ref": "#/components/schemas/Token"
}, },
@ -871,6 +1275,12 @@
"example": "test" "example": "test"
} }
} }
},
"TokenizeResponse": {
"type": "array",
"items": {
"$ref": "#/components/schemas/SimpleToken"
}
} }
} }
}, },

View File

@ -7,6 +7,8 @@
title: Installation title: Installation
- local: supported_models - local: supported_models
title: Supported Models and Hardware title: Supported Models and Hardware
- local: messages_api
title: Messages API
title: Getting started title: Getting started
- sections: - sections:
- local: basic_tutorials/consuming_tgi - local: basic_tutorials/consuming_tgi

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \ --shm-size 1g \
-e HUGGING_FACE_HUB_TOKEN=$token \ -e HUGGING_FACE_HUB_TOKEN=$token \
-p 8080:80 \ -p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 \ -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
--model-id $model --model-id $model
``` ```

View File

@ -60,9 +60,9 @@ Options:
[env: QUANTIZE=] [env: QUANTIZE=]
Possible values: Possible values:
- awq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models whereever possible because of the better latency - awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels whereever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
@ -354,6 +354,14 @@ Options:
[env: NGROK_EDGE=] [env: NGROK_EDGE=]
```
## TOKENIZER_CONFIG_PATH
```shell
--tokenizer-config-path <TOKENIZER_CONFIG_PATH>
The path to the tokenizer config file. This path is used to load the tokenizer configuration which may include a `chat_template`. If not provided, the default config will be used from the model hub
[env: TOKENIZER_CONFIG_PATH=]
``` ```
## ENV ## ENV
```shell ```shell

View File

@ -1,6 +1,6 @@
# Using TGI CLI # Using TGI CLI
You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](./installation#install-cli). You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli).
`text-generation-server` lets you download the model with `download-weights` command like below 👇 `text-generation-server` lets you download the model with `download-weights` command like below 👇

175
docs/source/messages_api.md Normal file
View File

@ -0,0 +1,175 @@
# Messages API
Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.
#### Table of Contents
- [Making a Request](#making-a-request)
- [Streaming](#streaming)
- [Synchronous](#synchronous)
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
- [Cloud Providers](#cloud-providers)
- [Amazon SageMaker](#amazon-sagemaker)
## Making a Request
You can make a request to TGI's Messages API using `curl`. Here's an example:
```bash
curl localhost:3000/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is deep learning?"
}
],
"stream": true,
"max_tokens": 20
}' \
-H 'Content-Type: application/json'
```
## Streaming
You can also use OpenAI's Python client library to make a streaming request. Here's how:
```python
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(
base_url="http://localhost:3000/v1",
api_key="-"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=True
)
# iterate and print stream
for message in chat_completion:
print(message)
```
## Synchronous
If you prefer to make a synchronous request, you can do so like this:
```python
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(
base_url="http://localhost:3000/v1",
api_key="-"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=False
)
print(chat_completion)
```
## Hugging Face Inference Endpoints
The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated).
Every endpoint that uses "Text Generation Inference" with an LLM, which has a chat template can now be used. Below is an example of how to use IE with TGI using OpenAI's Python client library:
> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key.
```python
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(
# replace with your endpoint url, make sure to include "v1/" at the end
base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/",
# replace with your API key
api_key="hf_XXX"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=True
)
# iterate and print stream
for message in chat_completion:
print(message.choices[0].delta.content, end="")
```
## Cloud Providers
TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:
## Amazon SageMaker
To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.
```python
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
try:
role = sagemaker.get_execution_role()
except ValueError:
iam = boto3.client('iam')
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
# Hub Model configuration. https://huggingface.co/models
hub = {
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
'SM_NUM_GPUS': json.dumps(1),
'MESSAGES_API_ENABLED': True
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
env=hub,
role=role,
)
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
initial_instance_count=1,
instance_type="ml.g5.2xlarge",
container_startup_health_check_timeout=300,
)
# send request
predictor.predict({
"messages": [
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
]
})
```

View File

@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
model=tiiuae/falcon-7b-instruct model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
``` ```
<Tip warning={true}> <Tip warning={true}>
@ -20,7 +20,7 @@ To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://d
TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead: TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:
```bash ```bash
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
``` ```
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
@ -91,7 +91,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash ```bash
docker run ghcr.io/huggingface/text-generation-inference:1.3 --help docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
``` ```
</Tip> </Tip>

View File

@ -19,7 +19,9 @@ The following models are optimized and can be served with TGI, which uses custom
- [MPT](https://huggingface.co/mosaicml/mpt-30b) - [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama) - [Llama V2](https://huggingface.co/meta-llama)
- [Code Llama](https://huggingface.co/codellama) - [Code Llama](https://huggingface.co/codellama)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) - [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Phi](https://huggingface.co/microsoft/phi-2)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models: If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
@ -41,8 +43,8 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed. TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future: TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
* Quantization (GPTQ, AWQ, etc.) * Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm) * Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
* Kernel for slinding window attention (Mistral) * Kernel for slinding window attention (Mistral)

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.76660156,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7246094,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.41333008,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.11785889,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.97265625,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.0569458,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
}

View File

@ -0,0 +1,60 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 6,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 284,
"logprob": -0.19421387,
"special": false,
"text": " to"
},
{
"id": 3758,
"logprob": -0.62597656,
"special": false,
"text": " send"
},
{
"id": 1366,
"logprob": -0.87060547,
"special": false,
"text": " data"
},
{
"id": 625,
"logprob": -0.88427734,
"special": false,
"text": " over"
},
{
"id": 257,
"logprob": -1.0830078,
"special": false,
"text": " a"
},
{
"id": 3127,
"logprob": -1.9462891,
"special": false,
"text": " network"
}
],
"top_tokens": null
},
"generated_text": "Test request to send data over a network"
}

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
}
]

View File

@ -16,52 +16,52 @@
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -9.09375, "logprob": -9.0859375,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.25976562, "logprob": -0.25830078,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -2.2148438, "logprob": -2.1875,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.3010254, "logprob": -0.30004883,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -5.6757812, "logprob": -5.6171875,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.0898438, "logprob": -3.078125,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -0.6791992, "logprob": -0.68066406,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.38891602, "logprob": -0.38745117,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.92041016, "logprob": -0.9453125,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -2.5390625, "logprob": -2.5371094,
"text": "]):" "text": "]):"
} }
], ],
@ -69,7 +69,7 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": 0.0, "logprob": -0.051635742,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
@ -81,7 +81,7 @@
}, },
{ {
"id": 11665, "id": 11665,
"logprob": -1.6005859, "logprob": -1.2236328,
"special": false, "special": false,
"text": " reduce" "text": " reduce"
}, },
@ -159,7 +159,7 @@
}, },
{ {
"id": 203, "id": 203,
"logprob": -0.11968994, "logprob": -0.12695312,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },

View File

@ -11,92 +11,92 @@
}, },
{ {
"id": 4911, "id": 4911,
"logprob": -5.7851562, "logprob": -6.9765625,
"text": "User" "text": "User"
}, },
{ {
"id": 29901, "id": 29901,
"logprob": -0.006996155, "logprob": -0.0059432983,
"text": ":" "text": ":"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -0.81347656, "logprob": -0.8408203,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 32001, "id": 32001,
"logprob": -6.687641e-05, "logprob": -9.906292e-05,
"text": "<image>" "text": "<image>"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -3.5762787e-07, "logprob": -2.3841858e-07,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 1815, "id": 1815,
"logprob": -4.2148438, "logprob": -4.1679688,
"text": "Can" "text": "Can"
}, },
{ {
"id": 366, "id": 366,
"logprob": -0.014137268, "logprob": -0.014099121,
"text": "you" "text": "you"
}, },
{ {
"id": 2649, "id": 2649,
"logprob": -4.4335938, "logprob": -4.4609375,
"text": "tell" "text": "tell"
}, },
{ {
"id": 592, "id": 592,
"logprob": -0.2919922, "logprob": -0.29882812,
"text": "me" "text": "me"
}, },
{ {
"id": 263, "id": 263,
"logprob": -4.2070312, "logprob": -4.1445312,
"text": "a" "text": "a"
}, },
{ {
"id": 1407, "id": 1407,
"logprob": -9.421875, "logprob": -9.3828125,
"text": "very" "text": "very"
}, },
{ {
"id": 3273, "id": 3273,
"logprob": -1.8720703, "logprob": -1.9736328,
"text": "short" "text": "short"
}, },
{ {
"id": 5828, "id": 5828,
"logprob": -0.26489258, "logprob": -0.2800293,
"text": "story" "text": "story"
}, },
{ {
"id": 2729, "id": 2729,
"logprob": -3.7441406, "logprob": -3.5625,
"text": "based" "text": "based"
}, },
{ {
"id": 373, "id": 373,
"logprob": -0.0005393028, "logprob": -0.0006427765,
"text": "on" "text": "on"
}, },
{ {
"id": 278, "id": 278,
"logprob": -0.140625, "logprob": -0.13952637,
"text": "the" "text": "the"
}, },
{ {
"id": 1967, "id": 1967,
"logprob": -0.06756592, "logprob": -0.068115234,
"text": "image" "text": "image"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -0.15454102, "logprob": -0.16357422,
"text": "?" "text": "?"
} }
], ],
@ -104,25 +104,25 @@
"tokens": [ "tokens": [
{ {
"id": 32002, "id": 32002,
"logprob": -0.0019140244, "logprob": -0.0026474,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
}, },
{ {
"id": 29871, "id": 29871,
"logprob": -8.404255e-05, "logprob": -8.547306e-05,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 13, "id": 13,
"logprob": -1.7642975e-05, "logprob": -1.7881393e-05,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 7900, "id": 7900,
"logprob": -2.9802322e-06, "logprob": -3.0994415e-06,
"special": false, "special": false,
"text": "Ass" "text": "Ass"
}, },
@ -140,30 +140,29 @@
}, },
{ {
"id": 319, "id": 319,
"logprob": -0.91064453, "logprob": -0.92529297,
"special": false, "special": false,
"text": " A" "text": " A"
}, },
{ {
"id": 696, "id": 696,
"logprob": -1.2412109, "logprob": -1.1269531,
"special": false, "special": false,
"text": " ro" "text": " ro"
}, },
{ {
"id": 15664, "id": 15664,
"logprob": -0.0002439022, "logprob": -0.00029492378,
"special": false, "special": false,
"text": "oster" "text": "oster"
}, },
{ {
"id": 15028, "id": 15028,
"logprob": -1.1630859, "logprob": -1.1855469,
"special": false, "special": false,
"text": " stands" "text": " stands"
} }
], ]
"top_tokens": null
}, },
"generated_text": " \nAssistant: A rooster stands" "generated_text": " \nAssistant: A rooster stands"
} }

View File

@ -12,92 +12,92 @@
}, },
{ {
"id": 4911, "id": 4911,
"logprob": -5.7851562, "logprob": -6.9804688,
"text": "User" "text": "User"
}, },
{ {
"id": 29901, "id": 29901,
"logprob": -0.006996155, "logprob": -0.006122589,
"text": ":" "text": ":"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -0.81347656, "logprob": -0.8417969,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 32001, "id": 32001,
"logprob": -6.687641e-05, "logprob": -9.918213e-05,
"text": "<image>" "text": "<image>"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -3.5762787e-07, "logprob": -2.3841858e-07,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 1815, "id": 1815,
"logprob": -4.2148438, "logprob": -4.1679688,
"text": "Can" "text": "Can"
}, },
{ {
"id": 366, "id": 366,
"logprob": -0.014137268, "logprob": -0.014091492,
"text": "you" "text": "you"
}, },
{ {
"id": 2649, "id": 2649,
"logprob": -4.4335938, "logprob": -4.4726562,
"text": "tell" "text": "tell"
}, },
{ {
"id": 592, "id": 592,
"logprob": -0.2919922, "logprob": -0.2998047,
"text": "me" "text": "me"
}, },
{ {
"id": 263, "id": 263,
"logprob": -4.2070312, "logprob": -4.15625,
"text": "a" "text": "a"
}, },
{ {
"id": 1407, "id": 1407,
"logprob": -9.421875, "logprob": -9.3828125,
"text": "very" "text": "very"
}, },
{ {
"id": 3273, "id": 3273,
"logprob": -1.8720703, "logprob": -1.9716797,
"text": "short" "text": "short"
}, },
{ {
"id": 5828, "id": 5828,
"logprob": -0.26489258, "logprob": -0.27734375,
"text": "story" "text": "story"
}, },
{ {
"id": 2729, "id": 2729,
"logprob": -3.7441406, "logprob": -3.5605469,
"text": "based" "text": "based"
}, },
{ {
"id": 373, "id": 373,
"logprob": -0.0005393028, "logprob": -0.00064468384,
"text": "on" "text": "on"
}, },
{ {
"id": 278, "id": 278,
"logprob": -0.140625, "logprob": -0.14160156,
"text": "the" "text": "the"
}, },
{ {
"id": 1967, "id": 1967,
"logprob": -0.06756592, "logprob": -0.06915283,
"text": "image" "text": "image"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -0.15454102, "logprob": -0.16381836,
"text": "?" "text": "?"
} }
], ],
@ -105,19 +105,19 @@
"tokens": [ "tokens": [
{ {
"id": 32002, "id": 32002,
"logprob": -0.0019140244, "logprob": -0.0026664734,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
}, },
{ {
"id": 29871, "id": 29871,
"logprob": -8.392334e-05, "logprob": -8.583069e-05,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 13, "id": 13,
"logprob": -1.7881393e-05, "logprob": -1.8119812e-05,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
@ -135,36 +135,35 @@
}, },
{ {
"id": 29901, "id": 29901,
"logprob": -3.0994415e-06, "logprob": -3.2186508e-06,
"special": false, "special": false,
"text": ":" "text": ":"
}, },
{ {
"id": 319, "id": 319,
"logprob": -0.9057617, "logprob": -0.9301758,
"special": false, "special": false,
"text": " A" "text": " A"
}, },
{ {
"id": 696, "id": 696,
"logprob": -1.2294922, "logprob": -1.1279297,
"special": false, "special": false,
"text": " ro" "text": " ro"
}, },
{ {
"id": 15664, "id": 15664,
"logprob": -0.00024533272, "logprob": -0.0002939701,
"special": false, "special": false,
"text": "oster" "text": "oster"
}, },
{ {
"id": 15028, "id": 15028,
"logprob": -1.1640625, "logprob": -1.1865234,
"special": false, "special": false,
"text": " stands" "text": " stands"
} }
], ]
"top_tokens": null
}, },
"generated_text": " \nAssistant: A rooster stands" "generated_text": " \nAssistant: A rooster stands"
}, },
@ -181,92 +180,92 @@
}, },
{ {
"id": 4911, "id": 4911,
"logprob": -5.7773438, "logprob": -6.9804688,
"text": "User" "text": "User"
}, },
{ {
"id": 29901, "id": 29901,
"logprob": -0.0070114136, "logprob": -0.006122589,
"text": ":" "text": ":"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -0.8208008, "logprob": -0.8417969,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 32001, "id": 32001,
"logprob": -6.699562e-05, "logprob": -9.942055e-05,
"text": "<image>" "text": "<image>"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -3.5762787e-07, "logprob": -2.3841858e-07,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 1815, "id": 1815,
"logprob": -4.2265625, "logprob": -4.1679688,
"text": "Can" "text": "Can"
}, },
{ {
"id": 366, "id": 366,
"logprob": -0.014175415, "logprob": -0.014091492,
"text": "you" "text": "you"
}, },
{ {
"id": 2649, "id": 2649,
"logprob": -4.4296875, "logprob": -4.4726562,
"text": "tell" "text": "tell"
}, },
{ {
"id": 592, "id": 592,
"logprob": -0.29516602, "logprob": -0.2998047,
"text": "me" "text": "me"
}, },
{ {
"id": 263, "id": 263,
"logprob": -4.2109375, "logprob": -4.15625,
"text": "a" "text": "a"
}, },
{ {
"id": 1407, "id": 1407,
"logprob": -9.4296875, "logprob": -9.3828125,
"text": "very" "text": "very"
}, },
{ {
"id": 3273, "id": 3273,
"logprob": -1.8720703, "logprob": -1.9716797,
"text": "short" "text": "short"
}, },
{ {
"id": 5828, "id": 5828,
"logprob": -0.26879883, "logprob": -0.27734375,
"text": "story" "text": "story"
}, },
{ {
"id": 2729, "id": 2729,
"logprob": -3.7675781, "logprob": -3.5605469,
"text": "based" "text": "based"
}, },
{ {
"id": 373, "id": 373,
"logprob": -0.0005354881, "logprob": -0.0006451607,
"text": "on" "text": "on"
}, },
{ {
"id": 278, "id": 278,
"logprob": -0.13671875, "logprob": -0.14160156,
"text": "the" "text": "the"
}, },
{ {
"id": 1967, "id": 1967,
"logprob": -0.06719971, "logprob": -0.06915283,
"text": "image" "text": "image"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -0.15551758, "logprob": -0.16381836,
"text": "?" "text": "?"
} }
], ],
@ -274,19 +273,19 @@
"tokens": [ "tokens": [
{ {
"id": 32002, "id": 32002,
"logprob": -0.0019130707, "logprob": -0.0026664734,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
}, },
{ {
"id": 29871, "id": 29871,
"logprob": -8.392334e-05, "logprob": -8.571148e-05,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 13, "id": 13,
"logprob": -1.7881393e-05, "logprob": -1.8119812e-05,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
@ -310,30 +309,29 @@
}, },
{ {
"id": 319, "id": 319,
"logprob": -0.9013672, "logprob": -0.9301758,
"special": false, "special": false,
"text": " A" "text": " A"
}, },
{ {
"id": 696, "id": 696,
"logprob": -1.2324219, "logprob": -1.1279297,
"special": false, "special": false,
"text": " ro" "text": " ro"
}, },
{ {
"id": 15664, "id": 15664,
"logprob": -0.0002477169, "logprob": -0.0002939701,
"special": false, "special": false,
"text": "oster" "text": "oster"
}, },
{ {
"id": 15028, "id": 15028,
"logprob": -1.1660156, "logprob": -1.1865234,
"special": false, "special": false,
"text": " stands" "text": " stands"
} }
], ]
"top_tokens": null
}, },
"generated_text": " \nAssistant: A rooster stands" "generated_text": " \nAssistant: A rooster stands"
}, },
@ -350,92 +348,92 @@
}, },
{ {
"id": 4911, "id": 4911,
"logprob": -5.7773438, "logprob": -6.9804688,
"text": "User" "text": "User"
}, },
{ {
"id": 29901, "id": 29901,
"logprob": -0.0070114136, "logprob": -0.006122589,
"text": ":" "text": ":"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -0.8208008, "logprob": -0.8417969,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 32001, "id": 32001,
"logprob": -6.699562e-05, "logprob": -9.918213e-05,
"text": "<image>" "text": "<image>"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -3.5762787e-07, "logprob": -2.3841858e-07,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 1815, "id": 1815,
"logprob": -4.2265625, "logprob": -4.1679688,
"text": "Can" "text": "Can"
}, },
{ {
"id": 366, "id": 366,
"logprob": -0.014175415, "logprob": -0.014091492,
"text": "you" "text": "you"
}, },
{ {
"id": 2649, "id": 2649,
"logprob": -4.4296875, "logprob": -4.4726562,
"text": "tell" "text": "tell"
}, },
{ {
"id": 592, "id": 592,
"logprob": -0.29516602, "logprob": -0.2998047,
"text": "me" "text": "me"
}, },
{ {
"id": 263, "id": 263,
"logprob": -4.2109375, "logprob": -4.15625,
"text": "a" "text": "a"
}, },
{ {
"id": 1407, "id": 1407,
"logprob": -9.4296875, "logprob": -9.3828125,
"text": "very" "text": "very"
}, },
{ {
"id": 3273, "id": 3273,
"logprob": -1.8720703, "logprob": -1.9716797,
"text": "short" "text": "short"
}, },
{ {
"id": 5828, "id": 5828,
"logprob": -0.26879883, "logprob": -0.27734375,
"text": "story" "text": "story"
}, },
{ {
"id": 2729, "id": 2729,
"logprob": -3.7675781, "logprob": -3.5605469,
"text": "based" "text": "based"
}, },
{ {
"id": 373, "id": 373,
"logprob": -0.0005354881, "logprob": -0.00064468384,
"text": "on" "text": "on"
}, },
{ {
"id": 278, "id": 278,
"logprob": -0.13671875, "logprob": -0.14160156,
"text": "the" "text": "the"
}, },
{ {
"id": 1967, "id": 1967,
"logprob": -0.06719971, "logprob": -0.06915283,
"text": "image" "text": "image"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -0.15551758, "logprob": -0.16381836,
"text": "?" "text": "?"
} }
], ],
@ -443,19 +441,19 @@
"tokens": [ "tokens": [
{ {
"id": 32002, "id": 32002,
"logprob": -0.001912117, "logprob": -0.0026664734,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
}, },
{ {
"id": 29871, "id": 29871,
"logprob": -8.392334e-05, "logprob": -8.59499e-05,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 13, "id": 13,
"logprob": -1.7762184e-05, "logprob": -1.8119812e-05,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
@ -479,30 +477,29 @@
}, },
{ {
"id": 319, "id": 319,
"logprob": -0.9013672, "logprob": -0.9301758,
"special": false, "special": false,
"text": " A" "text": " A"
}, },
{ {
"id": 696, "id": 696,
"logprob": -1.2324219, "logprob": -1.1279297,
"special": false, "special": false,
"text": " ro" "text": " ro"
}, },
{ {
"id": 15664, "id": 15664,
"logprob": -0.0002477169, "logprob": -0.0002939701,
"special": false, "special": false,
"text": "oster" "text": "oster"
}, },
{ {
"id": 15028, "id": 15028,
"logprob": -1.1660156, "logprob": -1.1865234,
"special": false, "special": false,
"text": " stands" "text": " stands"
} }
], ]
"top_tokens": null
}, },
"generated_text": " \nAssistant: A rooster stands" "generated_text": " \nAssistant: A rooster stands"
}, },
@ -519,92 +516,92 @@
}, },
{ {
"id": 4911, "id": 4911,
"logprob": -5.7773438, "logprob": -6.9804688,
"text": "User" "text": "User"
}, },
{ {
"id": 29901, "id": 29901,
"logprob": -0.0070114136, "logprob": -0.006122589,
"text": ":" "text": ":"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -0.8208008, "logprob": -0.8417969,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 32001, "id": 32001,
"logprob": -6.699562e-05, "logprob": -9.942055e-05,
"text": "<image>" "text": "<image>"
}, },
{ {
"id": 32000, "id": 32000,
"logprob": -3.5762787e-07, "logprob": -2.3841858e-07,
"text": "<fake_token_around_image>" "text": "<fake_token_around_image>"
}, },
{ {
"id": 1815, "id": 1815,
"logprob": -4.2265625, "logprob": -4.1679688,
"text": "Can" "text": "Can"
}, },
{ {
"id": 366, "id": 366,
"logprob": -0.014175415, "logprob": -0.014091492,
"text": "you" "text": "you"
}, },
{ {
"id": 2649, "id": 2649,
"logprob": -4.4296875, "logprob": -4.4726562,
"text": "tell" "text": "tell"
}, },
{ {
"id": 592, "id": 592,
"logprob": -0.29516602, "logprob": -0.2998047,
"text": "me" "text": "me"
}, },
{ {
"id": 263, "id": 263,
"logprob": -4.2109375, "logprob": -4.15625,
"text": "a" "text": "a"
}, },
{ {
"id": 1407, "id": 1407,
"logprob": -9.4296875, "logprob": -9.3828125,
"text": "very" "text": "very"
}, },
{ {
"id": 3273, "id": 3273,
"logprob": -1.8720703, "logprob": -1.9716797,
"text": "short" "text": "short"
}, },
{ {
"id": 5828, "id": 5828,
"logprob": -0.26879883, "logprob": -0.27734375,
"text": "story" "text": "story"
}, },
{ {
"id": 2729, "id": 2729,
"logprob": -3.7675781, "logprob": -3.5605469,
"text": "based" "text": "based"
}, },
{ {
"id": 373, "id": 373,
"logprob": -0.0005354881, "logprob": -0.0006451607,
"text": "on" "text": "on"
}, },
{ {
"id": 278, "id": 278,
"logprob": -0.13671875, "logprob": -0.14160156,
"text": "the" "text": "the"
}, },
{ {
"id": 1967, "id": 1967,
"logprob": -0.06719971, "logprob": -0.06915283,
"text": "image" "text": "image"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -0.15551758, "logprob": -0.16381836,
"text": "?" "text": "?"
} }
], ],
@ -612,19 +609,19 @@
"tokens": [ "tokens": [
{ {
"id": 32002, "id": 32002,
"logprob": -0.001912117, "logprob": -0.0026664734,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
}, },
{ {
"id": 29871, "id": 29871,
"logprob": -8.392334e-05, "logprob": -8.571148e-05,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 13, "id": 13,
"logprob": -1.7762184e-05, "logprob": -1.8119812e-05,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
@ -648,30 +645,29 @@
}, },
{ {
"id": 319, "id": 319,
"logprob": -0.9013672, "logprob": -0.9301758,
"special": false, "special": false,
"text": " A" "text": " A"
}, },
{ {
"id": 696, "id": 696,
"logprob": -1.2324219, "logprob": -1.1279297,
"special": false, "special": false,
"text": " ro" "text": " ro"
}, },
{ {
"id": 15664, "id": 15664,
"logprob": -0.0002477169, "logprob": -0.0002939701,
"special": false, "special": false,
"text": "oster" "text": "oster"
}, },
{ {
"id": 15028, "id": 15028,
"logprob": -1.1660156, "logprob": -1.1865234,
"special": false, "special": false,
"text": " stands" "text": " stands"
} }
], ]
"top_tokens": null
}, },
"generated_text": " \nAssistant: A rooster stands" "generated_text": " \nAssistant: A rooster stands"
} }

View File

@ -0,0 +1,63 @@
import pytest
@pytest.fixture(scope="module")
def flash_phi_handle(launcher):
with launcher("microsoft/phi-2", num_shard=1) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_phi(flash_phi_handle):
await flash_phi_handle.health(300)
return flash_phi_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response.generated_text == ': {request}")\n response = self'
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["network"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 6
assert response.generated_text == "Test request to send data over a network"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ': {request}")\n response = self'
assert responses == response_snapshot

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-integration-tests" name = "text-generation-integration-tests"
version = "1.3.4" version = "1.4.0"
description = "Text Generation Inference integration tests" description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"] authors = ["Nicolas Patry <nicolas@huggingface.co>"]

View File

@ -4,7 +4,7 @@ use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::ffi::OsString; use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read}; use std::io::{BufRead, BufReader, Lines};
use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path; use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio}; use std::process::{Child, Command, ExitStatus, Stdio};
@ -21,16 +21,16 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization { enum Quantization {
/// 4 bit quantization. Requires a specific GTPQ quantized model: /// 4 bit quantization. Requires a specific AWQ quantized model:
/// https://hf.co/models?search=awq. /// https://hf.co/models?search=awq.
/// Should replace GPTQ models whereever possible because of the better latency /// Should replace GPTQ models wherever possible because of the better latency
Awq, Awq,
/// 8 bit quantization, doesn't require specific model. /// 8 bit quantization, doesn't require specific model.
/// Should be a drop-in replacement to bitsandbytes with much better performance. /// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from https://github.com/NetEase-FuXi/EETQ.git /// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
Eetq, Eetq,
/// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. /// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq.
/// text-generation-inference will use exllama (faster) kernels whereever possible, and use /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not. /// triton kernel (wider support) when it's not.
/// AWQ has faster kernels. /// AWQ has faster kernels.
Gptq, Gptq,
@ -368,6 +368,11 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
/// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
/// include a `chat_template`. If not provided, the default config will be used from the model hub.
#[clap(long, env)]
tokenizer_config_path: Option<String>,
/// Display a lot of information about your runtime environment /// Display a lot of information about your runtime environment
#[clap(long, short, action)] #[clap(long, short, action)]
env: bool, env: bool,
@ -489,6 +494,9 @@ fn shard_manager(
// Safetensors load fast // Safetensors load fast
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
// Enable hf transfer for insane download speeds // Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
envs.push(( envs.push((
@ -573,6 +581,13 @@ fn shard_manager(
thread::spawn(move || { thread::spawn(move || {
log_lines(shard_stdout_reader.lines()); log_lines(shard_stdout_reader.lines());
}); });
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut ready = false; let mut ready = false;
let start_time = Instant::now(); let start_time = Instant::now();
@ -580,13 +595,6 @@ fn shard_manager(
loop { loop {
// Process exited // Process exited
if let Some(exit_status) = p.try_wait().unwrap() { if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut err = String::new(); let mut err = String::new();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line; err = err + "\n" + &line;
@ -782,6 +790,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Copy current process env // Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
// If huggingface_hub_cache is set, pass it to the download process // If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
@ -832,12 +843,20 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
} }
}; };
// Redirect STDOUT to the console let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
let download_stdout = download_process.stdout.take().unwrap();
let stdout = BufReader::new(download_stdout);
thread::spawn(move || { thread::spawn(move || {
log_lines(stdout.lines()); log_lines(download_stdout.lines());
});
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in download_stderr.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
}); });
loop { loop {
@ -848,12 +867,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
} }
let mut err = String::new(); let mut err = String::new();
download_process while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
.stderr err = err + "\n" + &line;
.take() }
.unwrap()
.read_to_string(&mut err)
.unwrap();
if let Some(signal) = status.signal() { if let Some(signal) = status.signal() {
tracing::error!( tracing::error!(
"Download process was signaled to shutdown with signal {signal}: {err}" "Download process was signaled to shutdown with signal {signal}: {err}"
@ -965,7 +982,20 @@ fn spawn_shards(
Ok(()) Ok(())
} }
fn compute_type(num_shard: usize) -> Option<String> {
let output = Command::new("nvidia-smi")
.args(["--query-gpu=gpu_name", "--format=csv"])
.output()
.ok()?;
let output = String::from_utf8(output.stdout).ok()?;
let fullname = output.split('\n').nth(1)?;
let cardname = fullname.replace(' ', "-").to_lowercase();
let compute_type = format!("{num_shard}-{cardname}");
Some(compute_type)
}
fn spawn_webserver( fn spawn_webserver(
num_shard: usize,
args: Args, args: Args,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
@ -1004,6 +1034,12 @@ fn spawn_webserver(
args.model_id, args.model_id,
]; ];
// Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string());
router_args.push(tokenizer_config_path.to_string());
}
// Model optional max batch total tokens // Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
router_args.push("--max-batch-total-tokens".to_string()); router_args.push("--max-batch-total-tokens".to_string());
@ -1049,6 +1085,13 @@ fn spawn_webserver(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
// Parse Compute type
if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
} else if let Some(compute_type) = compute_type(num_shard) {
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
}
let mut webserver = match Command::new("text-generation-router") let mut webserver = match Command::new("text-generation-router")
.args(router_args) .args(router_args)
.envs(envs) .envs(envs)
@ -1242,8 +1285,8 @@ fn main() -> Result<(), LauncherError> {
return Ok(()); return Ok(());
} }
let mut webserver = let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { .map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver); shutdown_shards(shutdown.clone(), &shutdown_receiver);
err err
})?; })?;

View File

@ -32,7 +32,7 @@ reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { version = "0.14.0", features = ["http"] } tokenizers = { version = "0.15.1", features = ["http"] }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] } tower-http = { version = "0.4.4", features = ["cors"] }

View File

@ -165,6 +165,28 @@ impl Infer {
)) ))
} }
/// Tokenizer the input
#[instrument(skip_all)]
pub(crate) async fn tokenize(
&self,
request: GenerateRequest,
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.await
.map_err(|err| {
tracing::error!("Tokenization {err}");
err
})?;
// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
}
/// Apply the chat template to the chat request /// Apply the chat template to the chat request
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> { pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {

View File

@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
pub fn from_file(filename: &str) -> Self { pub fn from_file(filename: &std::path::Path) -> Self {
let content = std::fs::read_to_string(filename).unwrap(); let content = std::fs::read_to_string(filename).unwrap();
serde_json::from_str(&content).unwrap_or_default() serde_json::from_str(&content).unwrap_or_default()
} }
@ -188,18 +188,20 @@ fn default_parameters() -> GenerateParameters {
} }
} }
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion { pub(crate) struct ChatCompletion {
pub id: String, pub id: String,
pub object: String, pub object: String,
#[schema(example = "1706270835")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
pub model: String, pub model: String,
pub system_fingerprint: String, pub system_fingerprint: String,
pub choices: Vec<ChatCompletionComplete>, pub choices: Vec<ChatCompletionComplete>,
pub usage: Usage, pub usage: Usage,
} }
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionComplete { pub(crate) struct ChatCompletionComplete {
pub index: u32, pub index: u32,
pub message: Message, pub message: Message,
@ -248,17 +250,19 @@ impl ChatCompletion {
} }
} }
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,
pub object: String, pub object: String,
#[schema(example = "1706270978")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
pub model: String, pub model: String,
pub system_fingerprint: String, pub system_fingerprint: String,
pub choices: Vec<ChatCompletionChoice>, pub choices: Vec<ChatCompletionChoice>,
} }
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChoice { pub(crate) struct ChatCompletionChoice {
pub index: u32, pub index: u32,
pub delta: ChatCompletionDelta, pub delta: ChatCompletionDelta,
@ -266,9 +270,11 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta { pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
pub role: String, pub role: String,
#[schema(example = "What is Deep Learning?")]
pub content: String, pub content: String,
} }
@ -311,7 +317,7 @@ fn default_request_messages() -> Vec<Message> {
#[derive(Clone, Deserialize, ToSchema, Serialize)] #[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct ChatRequest { pub(crate) struct ChatRequest {
/// UNUSED /// UNUSED
#[schema(example = "bigscience/blomm-560m")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, /* NOTE: UNUSED */ pub model: String, /* NOTE: UNUSED */
@ -322,6 +328,7 @@ pub(crate) struct ChatRequest {
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim. /// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)] #[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>, pub frequency_penalty: Option<f32>,
/// UNUSED /// UNUSED
@ -336,28 +343,33 @@ pub(crate) struct ChatRequest {
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message. /// output token returned in the content of message.
#[serde(default)] #[serde(default)]
#[schema(example = "false")]
pub logprobs: Option<bool>, pub logprobs: Option<bool>,
/// UNUSED /// UNUSED
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used. /// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)] #[serde(default)]
#[schema(example = "5")]
pub top_logprobs: Option<u32>, pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion. /// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)] #[serde(default)]
#[schema(example = "32")]
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
/// UNUSED /// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs. /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = "2")]
pub n: Option<u32>, pub n: Option<u32>,
/// UNUSED /// UNUSED
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics /// increasing the model's likelihood to talk about new topics
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = 0.1)]
pub presence_penalty: Option<f32>, pub presence_penalty: Option<f32>,
#[serde(default = "bool::default")] #[serde(default = "bool::default")]
@ -365,6 +377,20 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = 42)] #[schema(nullable = true, example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
@ -432,8 +458,21 @@ pub struct Token {
special: bool, special: bool,
} }
#[derive(Debug, Serialize, ToSchema)]
pub struct SimpleToken {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(example = 0)]
start: usize,
#[schema(example = 2)]
stop: usize,
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))] #[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason { pub(crate) enum FinishReason {
#[schema(rename = "length")] #[schema(rename = "length")]
Length, Length,
@ -494,6 +533,10 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>, pub details: Option<Details>,
} }
#[derive(Serialize, ToSchema)]
#[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails { pub(crate) struct StreamDetails {
#[schema(example = "length")] #[schema(example = "length")]
@ -524,26 +567,12 @@ pub(crate) struct ErrorResponse {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer { pub(crate) async fn get_tokenizer() -> Tokenizer {
let filename = std::path::Path::new("tokenizer.json"); let api = hf_hub::api::sync::Api::new().unwrap();
if !filename.exists() { let repo = api.model("gpt2".to_string());
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json") let filename = repo.get("tokenizer.json").unwrap();
.await Tokenizer::from_file(filename).unwrap()
.unwrap()
.bytes()
.await
.unwrap();
let tmp_filename = "tokenizer.json.temp";
let mut file = std::fs::File::create(tmp_filename).unwrap();
file.write_all(&content).unwrap();
// Re-check if another process has written this file maybe.
if !filename.exists() {
std::fs::rename(tmp_filename, filename).unwrap()
}
}
Tokenizer::from_file("tokenizer.json").unwrap()
} }
} }

View File

@ -72,7 +72,7 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
chat_enabled_api: bool, messages_api_enabled: bool,
} }
#[tokio::main] #[tokio::main]
@ -104,7 +104,7 @@ async fn main() -> Result<(), RouterError> {
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
chat_enabled_api, messages_api_enabled,
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> {
let local_path = Path::new(&tokenizer_name); let local_path = Path::new(&tokenizer_name);
let local_model = local_path.exists() && local_path.is_dir(); let local_model = local_path.exists() && local_path.is_dir();
// Load tokenizer config
// This will be used to format the chat template
let local_tokenizer_config_path =
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
// Shared API builder initialization // Shared API builder initialization
let api_builder = || { let api_builder = || {
let mut builder = ApiBuilder::new() let mut builder = ApiBuilder::new()
@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> {
}; };
// Load tokenizer config if found locally, or check if we can get it from the API if needed // Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if local_tokenizer_config { let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!("Using local tokenizer config from user specified path");
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
} else if local_model {
tracing::info!("Using local tokenizer config"); tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_tokenizer_config_path) HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
} else if let Some(api) = api {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
get_tokenizer_config(&api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main".to_string()),
)))
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
HubTokenizerConfig::default()
})
} else { } else {
tracing::warn!("Could not find tokenizer config locally and no revision specified"); match api {
HubTokenizerConfig::default() Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
let repo = Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or("main".to_string()),
);
get_tokenizer_config(&api.repo(repo))
.await
.unwrap_or_else(|| {
tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
}
}
}; };
if tokenizer.is_none() { if tokenizer.is_none() {
@ -348,7 +353,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
tokenizer_config, tokenizer_config,
chat_enabled_api, messages_api_enabled,
) )
.await?; .await?;
Ok(()) Ok(())
@ -462,7 +467,12 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
let reader = BufReader::new(file); let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?; let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config) Some(tokenizer_config)
} }

View File

@ -3,10 +3,10 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest, BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse, GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
Token, Validation, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
async fn compat_generate( async fn compat_generate(
Extension(default_return_full_text): Extension<bool>, Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>, infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
Json(mut req): Json<CompatGenerateRequest>, Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag // default return_full_text given the pipeline_tag
@ -66,11 +67,11 @@ async fn compat_generate(
// switch on stream // switch on stream
if req.stream { if req.stream {
Ok(generate_stream(infer, Json(req.into())) Ok(generate_stream(infer, compute_type, Json(req.into()))
.await .await
.into_response()) .into_response())
} else { } else {
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?; let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response()) Ok((headers, Json(vec![generation])).into_response())
} }
@ -145,6 +146,7 @@ seed,
)] )]
async fn generate( async fn generate(
infer: Extension<Infer>, infer: Extension<Infer>,
Extension(ComputeType(compute_type)): Extension<ComputeType>,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
@ -230,7 +232,7 @@ async fn generate(
// Headers // Headers
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert( headers.insert(
"x-compute-time", "x-compute-time",
total_time.as_millis().to_string().parse().unwrap(), total_time.as_millis().to_string().parse().unwrap(),
@ -339,6 +341,7 @@ seed,
)] )]
async fn generate_stream( async fn generate_stream(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> ( ) -> (
HeaderMap, HeaderMap,
@ -349,13 +352,14 @@ async fn generate_stream(
event.json_data(stream_token).unwrap() event.json_data(stream_token).unwrap()
}; };
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, Json(req), on_message_callback).await; generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse) (headers, sse)
} }
async fn generate_stream_internal( async fn generate_stream_internal(
infer: Infer, infer: Infer,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(StreamResponse) -> Event, on_message_callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) { ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
@ -368,7 +372,7 @@ async fn generate_stream_internal(
let compute_characters = req.inputs.chars().count(); let compute_characters = req.inputs.chars().count();
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert( headers.insert(
"x-compute-characters", "x-compute-characters",
compute_characters.to_string().parse().unwrap(), compute_characters.to_string().parse().unwrap(),
@ -532,7 +536,7 @@ async fn generate_stream_internal(
path = "/v1/chat/completions", path = "/v1/chat/completions",
request_body = ChatRequest, request_body = ChatRequest,
responses( responses(
(status = 200, description = "Generated Text", body = GenerateResponse), (status = 200, description = "Generated Text", body = ChatCompletionChunk),
(status = 424, description = "Generation Error", body = ErrorResponse, (status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})), example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse, (status = 429, description = "Model is overloaded", body = ErrorResponse,
@ -557,6 +561,7 @@ async fn generate_stream_internal(
)] )]
async fn chat_completions( async fn chat_completions(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>, Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@ -592,10 +597,10 @@ async fn chat_completions(
inputs: inputs.to_string(), inputs: inputs.to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: None, temperature: req.temperature,
repetition_penalty, repetition_penalty,
top_k: None, top_k: None,
top_p: None, top_p: req.top_p,
typical_p: None, typical_p: None,
do_sample: true, do_sample: true,
max_new_tokens, max_new_tokens,
@ -604,7 +609,7 @@ async fn chat_completions(
truncate: None, truncate: None,
watermark: false, watermark: false,
details: true, details: true,
decoder_input_details: true, decoder_input_details: !stream,
seed, seed,
top_n_tokens: None, top_n_tokens: None,
}, },
@ -644,13 +649,22 @@ async fn chat_completions(
) )
}; };
let (headers, response_stream) = let (headers, response_stream) = generate_stream_internal(
generate_stream_internal(infer, Json(generate_request), on_message_callback).await; infer,
compute_type,
Json(generate_request),
on_message_callback,
)
.await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
let (headers, Json(generation)) = let (headers, Json(generation)) = generate(
generate(Extension(infer), Json(generate_request)).await?; Extension(infer),
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -672,6 +686,52 @@ async fn chat_completions(
} }
} }
/// Tokenize inputs
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/tokenize",
request_body = GenerateRequest,
responses(
(status = 200, description = "Tokenized ids", body = TokenizeResponse),
(status = 404, description = "No tokenizer found", body = ErrorResponse,
example = json ! ({"error": "No fast tokenizer available"})),
)
)]
#[instrument(skip_all)]
async fn tokenize(
Extension(infer): Extension<Infer>,
Json(req): Json<GenerateRequest>,
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
let input = req.inputs.clone();
let encoding = infer.tokenize(req).await?;
if let Some(encoding) = encoding {
let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
.zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| {
let text: String = input.chars().skip(start).take(stop - start).collect();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect();
Ok(Json(TokenizeResponse(tokens)))
} else {
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
error_type: "no fast tokenizer".to_string(),
}),
))
}
}
/// Prometheus metrics scrape endpoint /// Prometheus metrics scrape endpoint
#[utoipa::path( #[utoipa::path(
get, get,
@ -683,6 +743,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render() prom_handle.render()
} }
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
@ -708,7 +771,7 @@ pub async fn run(
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
chat_enabled_api: bool, messages_api_enabled: bool,
) -> Result<(), axum::BoxError> { ) -> Result<(), axum::BoxError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -719,6 +782,8 @@ pub async fn run(
compat_generate, compat_generate,
generate, generate,
generate_stream, generate_stream,
chat_completions,
tokenize,
metrics, metrics,
), ),
components( components(
@ -726,10 +791,18 @@ pub async fn run(
Info, Info,
CompatGenerateRequest, CompatGenerateRequest,
GenerateRequest, GenerateRequest,
ChatRequest,
Message,
ChatCompletionChoice,
ChatCompletionDelta,
ChatCompletionChunk,
ChatCompletion,
GenerateParameters, GenerateParameters,
PrefillToken, PrefillToken,
Token, Token,
GenerateResponse, GenerateResponse,
TokenizeResponse,
SimpleToken,
BestOfSequence, BestOfSequence,
Details, Details,
FinishReason, FinishReason,
@ -863,21 +936,26 @@ pub async fn run(
// Define base and health routes // Define base and health routes
let base_routes = Router::new() let base_routes = Router::new()
.route("/", post(compat_generate)) .route("/", post(compat_generate))
.route("/", get(health))
.route("/info", get(get_model_info)) .route("/info", get(get_model_info))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
.route("/tokenize", post(tokenize))
.route("/health", get(health)) .route("/health", get(health))
.route("/ping", get(health)) .route("/ping", get(health))
.route("/metrics", get(metrics)); .route("/metrics", get(metrics));
// Conditional AWS Sagemaker route // Conditional AWS Sagemaker route
let aws_sagemaker_route = if chat_enabled_api { let aws_sagemaker_route = if messages_api_enabled {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
} else { } else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
}; };
let compute_type =
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
// Combine routes and layers // Combine routes and layers
let app = Router::new() let app = Router::new()
.merge(swagger_ui) .merge(swagger_ui)
@ -887,6 +965,7 @@ pub async fn run(
.layer(Extension(health_ext.clone())) .layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(Extension(compute_type))
.layer(Extension(prom_handle.clone())) .layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default()) .layer(OtelAxumLayer::default())
.layer(cors_layer); .layer(cors_layer);

View File

@ -70,12 +70,11 @@ impl Validation {
} }
#[instrument(skip(self, inputs))] #[instrument(skip(self, inputs))]
async fn validate_input( pub async fn tokenize(
&self, &self,
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, ) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
) -> Result<(String, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some(sender) = &self.sender { if let Some(sender) = &self.sender {
// Create response channel // Create response channel
@ -88,7 +87,24 @@ impl Validation {
// Await on response channel // Await on response channel
// Unwrap is safe here // Unwrap is safe here
let (inputs, input_length) = response_receiver.await.unwrap()?; let encoding = response_receiver.await.unwrap()?;
Ok(Some(encoding))
} else {
Ok(None)
}
}
#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
let input_length = encoding.len();
// Get total tokens // Get total tokens
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
@ -343,36 +359,31 @@ fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<
/// Get input length and optionally truncate it /// Get input length and optionally truncate it
fn prepare_input( fn prepare_input(
inputs: String, mut inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
) -> Result<(String, usize), ValidationError> { ) -> Result<(tokenizers::Encoding, String), ValidationError> {
// Get the number of tokens in the input // Get the number of tokens in the input
let mut encoding = tokenizer let mut encoding = tokenizer
.encode(inputs.clone(), true) .encode(inputs.clone(), true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?; .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
// Optionally truncate // Optionally truncate
let (inputs, input_length) = match truncate { if let Some(truncate) = truncate {
// Truncate is some and < encoding length if truncate < encoding.len() {
Some(truncate) if truncate < encoding.len() => {
// truncate encoding and decode new inputs
encoding.truncate(truncate, 0, TruncationDirection::Left); encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer inputs = tokenizer
.decode(encoding.get_ids(), false) .decode(encoding.get_ids(), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?; .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
} }
// Nothing to do }
_ => (inputs, encoding.len()),
};
Ok((inputs, input_length)) Ok((encoding, inputs))
} }
type TokenizerRequest = ( type TokenizerRequest = (
(String, Option<usize>), (String, Option<usize>),
oneshot::Sender<Result<(String, usize), ValidationError>>, oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
Span, Span,
); );

View File

@ -1,4 +1,4 @@
eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1 eetq_commit := 71adb5e191bb8290069a580abff0355d7b2dd5c9
eetq: eetq:
# Clone eetq # Clone eetq
@ -6,7 +6,7 @@ eetq:
git clone https://github.com/NetEase-FuXi/EETQ.git eetq git clone https://github.com/NetEase-FuXi/EETQ.git eetq
build-eetq: eetq build-eetq: eetq
cd eetq && git fetch && git checkout $(eetq_commit) cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
cd eetq && python setup.py build cd eetq && python setup.py build
install-eetq: build-eetq install-eetq: build-eetq

View File

@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
// //
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif #endif

View File

@ -2,8 +2,11 @@
#include "column_remap.cuh" #include "column_remap.cuh"
#include "../util.cuh" #include "../util.cuh"
#include "../matrix.cuh" #include "../matrix.cuh"
#include "../cuda_compat.cuh" #include "../cu_compat.cuh"
#include "../cuda_buffers.cuh" #include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out const int THREADS_Y = 1; // Block size and thread count along rows in x and out
@ -128,7 +131,7 @@ __global__ void q4_matmul_kernel
if constexpr (use_half2) if constexpr (use_half2)
{ {
half result = __hadd(acc.x, acc.y); half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result); atomicAdd(out_.item_ptr(x_row, w_column), result);
} }
else else

View File

@ -1,12 +1,23 @@
#ifndef _compat_gemm_cuh // Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _compat_gemm_cuh
#if defined(USE_ROCM) #ifndef _hip_compat_cuh
#define _hip_compat_cuh
// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required // Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
// for symbols as hipblasHalf. __device__ __forceinline__ __half __compat_hrcp(__half x) {
#include <hipblas/hipblas.h> return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA, hipblasOperation_t transA,
hipblasOperation_t transB, hipblasOperation_t transB,
@ -31,8 +42,10 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t
#define hipblasHgemm __compat_hipblasHgemm #define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N #define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm #define rocblas_hgemm __compat_hipblasHgemm
#endif
#endif #endif

View File

@ -8,7 +8,11 @@
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase #define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess // React to failure on return code != cudaSuccess

View File

@ -1,5 +1,15 @@
from setuptools import setup from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_cuda_cflags = ["-lineinfo", "-O3"]
if torch.version.hip:
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
extra_compile_args = {
"nvcc": extra_cuda_cflags,
}
setup( setup(
name="exllamav2_kernels", name="exllamav2_kernels",
@ -11,6 +21,7 @@ setup(
"exllamav2_kernels/cuda/q_matrix.cu", "exllamav2_kernels/cuda/q_matrix.cu",
"exllamav2_kernels/cuda/q_gemm.cu", "exllamav2_kernels/cuda/q_gemm.cu",
], ],
extra_compile_args=extra_compile_args,
) )
], ],
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": BuildExtension},

1193
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-server" name = "text-generation-server"
version = "1.3.4" version = "1.4.0"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97" sentencepiece = "^0.1.97"
tokenizers = "^0.15.0" tokenizers = "^0.15.0"
huggingface-hub = "^0.19.3" huggingface-hub = "^0.19.3"
transformers = "^4.36.1" transformers = "^4.37.1"
einops = "^0.6.1" einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true } texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true } datasets = { version = "^2.14.0", optional = true }

View File

@ -13,11 +13,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -28,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13" regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13" scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13" transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -12,11 +12,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -27,18 +27,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13" regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13" scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13" transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,77 @@
import torch
from text_generation_server.utils.layers import (
TensorParallelEmbedding,
)
class ProcessGroup:
def __init__(self, rank: int, world_size: int):
self._rank = rank
self.world_size = world_size
def size(self) -> int:
return self.world_size
def rank(self) -> int:
return self._rank
class Weights:
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
self.weight = (
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
)
self.process_group = ProcessGroup(rank, world_size)
def get_partial_sharded(self, name: str, dim: int):
assert dim == 0
rank = self.process_group.rank()
world_size = self.process_group.size()
size = self.weight.shape[dim]
block_size = (size + world_size - 1) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
return self.weight[start:stop]
def get_shape(self, name: str):
return self.weight.shape
def test_weight_hub_files_offline_error():
vocab_size = 17
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
output = embeddings.forward(input_ids)
assert embeddings.min_id == 0
assert embeddings.max_id == 17
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
assert embeddings_0_2.min_id == 0
assert embeddings_0_2.max_id == 9
torch.testing.assert_close(
embeddings_0_2.weight,
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
.view(10, 256)
.float(),
)
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
assert embeddings_1_2.min_id == 9
assert embeddings_1_2.max_id == 17
torch.testing.assert_close(
embeddings_1_2.weight,
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
.view(9, 256)
.float(),
)
output_tp_0 = embeddings_0_2.forward(input_ids)
output_tp_1 = embeddings_1_2.forward(input_ids)
torch.testing.assert_close(output, output_tp_0 + output_tp_1)

View File

@ -50,19 +50,39 @@ def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5] top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens) top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5) inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
accepted_ids = torch.ones_like(top_n_tokens_tensor)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens( topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
) )
assert topn_tok_ids[0] == [] assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [0, 3] assert topn_tok_ids[1] == [[0, 3]]
assert topn_tok_ids[2] == [0, 3, 1, 4] assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [0, 3, 1, 4] assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [0, 3, 1, 4, 2] assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [] assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [-1, -2] assert topn_tok_logprobs[1] == [[-1, -2]]
assert topn_tok_logprobs[2] == [-1, -2, -3, -3] assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [-1, -2, -3, -3] assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4] assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
# Now let's make second member of the batch be speculated
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
accepted_ids[1] = 2
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3], [0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]

View File

@ -19,6 +19,7 @@ from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.mamba import Mamba from text_generation_server.models.mamba import Mamba
from text_generation_server.models.phi import Phi
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later. # in PyTorch 1.12 and later.
@ -58,6 +59,7 @@ try:
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e: except ImportError as e:
@ -73,6 +75,7 @@ if FLASH_ATTENTION:
__all__.append(IDEFICSSharded) __all__.append(IDEFICSSharded)
__all__.append(FlashMistral) __all__.append(FlashMistral)
__all__.append(FlashMixtral) __all__.append(FlashMixtral)
__all__.append(FlashPhi)
def get_model( def get_model(
@ -247,6 +250,39 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == "phi":
if FLASH_ATTENTION:
return FlashPhi(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
)
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "phi-msft":
if FLASH_ATTENTION:
raise NotImplementedError(
"Legacy phi-msft is not supported with Flash Attention"
)
else:
return Phi(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "llama" or model_type == "baichuan": elif model_type == "llama" or model_type == "baichuan":
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashLlama( return FlashLlama(

View File

@ -580,10 +580,13 @@ class CausalLM(Model):
generations: List[Generation] = [] generations: List[Generation] = []
stopped = True stopped = True
# Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens,
batch.top_n_tokens_tensor, batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1), torch.log_softmax(logits[:, -1], -1),
accepted_ids,
) )
start_decode = time.time_ns() start_decode = time.time_ns()
@ -692,20 +695,24 @@ class CausalLM(Model):
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0: if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode( all_top_tokens = []
top_token_ids, for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
clean_up_tokenization_spaces=False, toptoken_texts = self.tokenizer.batch_decode(
skip_special_tokens=False, top_token_ids,
) clean_up_tokenization_spaces=False,
special_toptokens = [ skip_special_tokens=False,
token_id in self.all_special_ids for token_id in top_token_ids )
] special_toptokens = [
top_tokens = Tokens( token_id in self.all_special_ids for token_id in top_token_ids
top_token_ids, ]
top_token_logprobs, top_tokens = Tokens(
toptoken_texts, top_token_ids,
special_toptokens, top_token_logprobs,
) toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else: else:
top_tokens = None top_tokens = None

View File

@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads self.head_size = hidden_size // num_heads
self.rotary_dim = int(config.rotary_pct * self.head_size)
if self.num_heads % weights.process_group.size() != 0: if self.num_heads % weights.process_group.size() != 0:
raise ValueError( raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@ -98,8 +100,11 @@ class FlashNeoxAttention(torch.nn.Module):
) )
self.num_heads = self.num_heads // weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load( self.rotary_emb = PositionRotaryEmbedding.static(
config=config, prefix=f"{prefix}.rotary_emb", weights=weights config=config,
dim=self.rotary_dim,
base=config.rotary_emb_base,
device=weights.device,
) )
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)

View File

@ -0,0 +1,410 @@
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastLayerNorm,
)
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
hidden_size=2560,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="gelu_fast", # llama uses silu
layer_norm_eps=1e-05, # rms in llama,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
resid_pdrop=0.1, # llama doesn't have this
partial_rotary_factor=0.5, # important difference between llama and phi
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.rope_theta = rope_theta
self.resid_pdrop = resid_pdrop
self.partial_rotary_factor = partial_rotary_factor
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
# this is the same as llama except for Phi uses bias=True
return TensorParallelColumnLinear(
get_linear(weight, bias=True, quantize=config.quantize)
)
class FlashPhiAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.softmax_scale = self.head_size**-0.5
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.rotary_dim,
base=config.rope_theta,
device=weights.device,
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.dense",
weights=weights,
bias=True,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
# Compute query, key, value and split
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
# Reshape query and key for rotary embeddings
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# NOTE: this is the main difference between Llama and Phi
# in llama the rotary embeddings are applied to the whole query and key.
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
#
# Apply partial positional embeddings in place
self.rotary_emb(
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
)
# Reshape key and value and cache
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
)
# llama weights are up_proj and down_proj and bias=False
self.up_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=True,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=True,
)
def forward(self, hidden_states):
# NOTE: Llama requires the gate up states to an intermediate size
# Phi does not and we can avoid the `view` operation
return self.down_proj(self.act(self.up_proj(hidden_states)))
class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = self.resid_dropout(attn_output).add(
self.resid_dropout(self.mlp(hidden_states))
)
return hidden_states, res
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashPhiLayer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
self.norm = FastLayerNorm.load(
prefix="model.final_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = FlashPhiModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
return self.lm_head(hidden_states)

View File

@ -28,7 +28,6 @@ EPS = 1e-5
def load_col(config, prefix, weights, bias): def load_col(config, prefix, weights, bias):
assert bias == False, NotImplementedError
assert config.quantize != "gptq", NotImplementedError assert config.quantize != "gptq", NotImplementedError
slice_ = weights._get_slice(f"{prefix}.weight") slice_ = weights._get_slice(f"{prefix}.weight")
rank = weights.process_group.rank() rank = weights.process_group.rank()
@ -45,7 +44,36 @@ def load_col(config, prefix, weights, bias):
if weight.dtype != torch.int32: if weight.dtype != torch.int32:
weight = weight.to(dtype=weights.dtype) weight = weight.to(dtype=weights.dtype)
weight = weight.to(device=weights.device) weight = weight.to(device=weights.device)
bias = None
if bias:
bias_slice_ = weights._get_slice(f"{prefix}.bias")
bias_rank = weights.process_group.rank()
bias_size = weights.process_group.size()
bias_h = bias_slice_.get_shape()
bias_h = bias_h[0]
bias_block_size = bias_h // bias_size
bias_q_part = bias_slice_[
bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
]
bias_k_part = bias_slice_[
bias_h
+ bias_rank * bias_block_size : bias_h
+ (bias_rank + 1) * bias_block_size
]
bias_v_part = bias_slice_[
2 * bias_h
+ bias_rank * bias_block_size : 2 * bias_h
+ (bias_rank + 1) * bias_block_size
]
bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
if bias.dtype != torch.int32:
bias = bias.to(dtype=weights.dtype)
bias = bias.to(device=weights.device)
else:
bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias, config.quantize)
return TensorParallelColumnLinear(linear) return TensorParallelColumnLinear(linear)
@ -330,7 +358,16 @@ class MultiheadAttention(nn.Module):
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
) )
if self.qk_ln: if self.qk_ln:
raise NotImplementedError("qk_ln is not supported") bias = not config.no_bias
hidden_size = config.d_model
head_dim = hidden_size // self.n_heads
self.q_ln = LPLayerNorm(
d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
)
self.k_ln = LPLayerNorm(
self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
)
if self.attn_impl == "flash": if self.attn_impl == "flash":
self.attn_fn = flash_attn_fn self.attn_fn = flash_attn_fn
elif self.attn_impl == "triton": elif self.attn_impl == "triton":
@ -581,12 +618,20 @@ class MPTBlock(nn.Module):
f"""Not implemented attn {config.attn_config["attn_type"]}""" f"""Not implemented attn {config.attn_config["attn_type"]}"""
) )
resid_pdrop = config.resid_pdrop resid_pdrop = config.resid_pdrop
self.norm_1 = nn.LayerNorm.load_no_bias( if config.no_bias:
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS self.norm_1 = nn.LayerNorm.load_no_bias(
) prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
self.norm_2 = nn.LayerNorm.load_no_bias( )
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS self.norm_2 = nn.LayerNorm.load_no_bias(
) prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
else:
self.norm_1 = nn.LayerNorm.load(
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
)
self.norm_2 = nn.LayerNorm.load(
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@ -635,6 +680,9 @@ class LPLayerNorm(torch.nn.LayerNorm):
elementwise_affine=True, elementwise_affine=True,
device=None, device=None,
dtype=None, dtype=None,
bias: Optional[bool] = True,
prefix=None,
weights=None,
): ):
super().__init__( super().__init__(
normalized_shape=normalized_shape, normalized_shape=normalized_shape,
@ -642,7 +690,13 @@ class LPLayerNorm(torch.nn.LayerNorm):
elementwise_affine=elementwise_affine, elementwise_affine=elementwise_affine,
device=device, device=device,
dtype=dtype, dtype=dtype,
bias=bias,
) )
if weights is not None:
self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
if bias:
self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
self.normalized_shape = self.weight.shape
def forward(self, x): def forward(self, x):
module_device = x.device module_device = x.device
@ -755,20 +809,23 @@ class MPTModel(MPTPreTrainedModel):
) )
self.wte = TensorParallelEmbedding("transformer.wte", weights) self.wte = TensorParallelEmbedding("transformer.wte", weights)
if not self.alibi: if not self.alibi:
# self.wpe = torch.nn.Embedding( self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
# config.max_seq_len, config.d_model, device=config.init_device
# )
raise RuntimeError("no alibi no supported")
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
for i in range(config.n_layers) for i in range(config.n_layers)
] ]
) )
self.norm_f = nn.LayerNorm.load_no_bias( if config.no_bias:
prefix="transformer.norm_f", weights=weights, eps=EPS self.norm_f = nn.LayerNorm.load_no_bias(
) prefix="transformer.norm_f", weights=weights, eps=EPS
)
else:
self.norm_f = nn.LayerNorm.load(
prefix="transformer.norm_f", weights=weights, eps=EPS
)
self.is_causal = not self.prefix_lm self.is_causal = not self.prefix_lm
self._attn_bias_initialized = False self._attn_bias_initialized = False
self.attn_bias = None self.attn_bias = None
@ -787,8 +844,9 @@ class MPTModel(MPTPreTrainedModel):
if config.verbose: if config.verbose:
warnings.warn(f"Removing bias ({module.bias}) from {module}.") warnings.warn(f"Removing bias ({module.bias}) from {module}.")
module.register_parameter("bias", None) module.register_parameter("bias", None)
if config.verbose and config.verbose > 2: if hasattr(self.config, "verbose"):
print(self) if config.verbose and config.verbose > 2:
print(self)
if "verbose" not in self.config.init_config: if "verbose" not in self.config.init_config:
self.config.init_config["verbose"] = self.config.verbose self.config.init_config["verbose"] = self.config.verbose
if self.config.init_config["verbose"] > 1: if self.config.init_config["verbose"] > 1:

View File

@ -0,0 +1,330 @@
# imlementation of the PhiModel and PhiForCausalLM classes
import torch
import torch.distributed
import math
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLinear,
)
# PhiConfig is the configuration class for the PhiModel.
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
n_positions=2048,
n_embd=2560,
n_layer=32,
n_inner=None,
n_head=32,
rotary_dim=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_vocab_size_multiple=64,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
no_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = rotary_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_word_embeddings = tie_word_embeddings
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.no_bias = no_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
inv_freq_len = len(inv_freq)
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
freqs = t.matmul(inv_freq)
self.sin = freqs.sin()
self.cos = freqs.cos()
def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
b_size, seqlen, three, _, _headdim = qkv.shape
if three != 3:
raise Exception("unexpected shape for qkv")
_, rotary_dim = self.cos.shape
rotary_dim = rotary_dim * 2
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
q12 = torch.chunk(q_rot, 2, dim=-1)
k12 = torch.chunk(k_rot, 2, dim=-1)
q1, q2 = q12[0], q12[1]
k1, k2 = k12[0], k12[1]
c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
q_rot = torch.cat(
[
q1 * c - q2 * s,
q1 * s + q2 * c,
],
dim=-1,
)
k_rot = torch.cat(
[
k1 * c - k2 * s,
k1 * s + k2 * c,
],
dim=-1,
)
q = torch.cat([q_rot, q_pass], dim=-1)
k = torch.cat([k_rot, k_pass], dim=-1)
v = qkv[:, :, 2]
return q, k, v
# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
class PhiCausalLMHead(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.ln = nn.LayerNorm.load(
prefix="lm_head.ln",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.linear = TensorParallelHead.load(
config=config, prefix="lm_head.linear", weights=weights
)
def forward(self, hidden_states):
hidden_states = self.ln(hidden_states)
hidden_states = self.linear(hidden_states)
return hidden_states
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=not config.no_bias,
)
self.op_size = config.n_embd
self.head_dim = int(config.n_embd / config.n_head)
self.num_heads = config.n_head
self.rotary_emb = RotaryEmbedding(
config.rotary_dim,
config.n_positions,
)
self.softmax_scale = 1.0 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states,
past_kv_cache,
attention_mask=None,
):
b_size, seq_len, _n_embd = hidden_states.shape
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)
# if there is a kv_cache, then we need to concatenate
if past_kv_cache is not None:
prev_k, prev_v = past_kv_cache
k = torch.cat([prev_k, k], dim=1)
v = torch.cat([prev_v, v], dim=1)
past_kv_cache = [k, v]
attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
if attention_mask is not None:
seqlen_k = k.shape[1]
seqlen_q = q.shape[1]
causal_mask = torch.triu(
torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
1,
)
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
attn_output = (
attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
.transpose(1, 2)
.flatten(-2)
)
return self.out_proj(attn_output), past_kv_cache
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.n_inner = config.n_inner
self.fc1 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=False,
)
self.fc2 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=False,
)
self.activation = torch.nn.functional.gelu
def forward(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
class PhiBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.layer_norm = nn.LayerNorm.load(
prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
)
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
def forward(
self,
hidden_states,
kv_cache,
attention_mask,
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, past_kv_cache = self.mixer(
hidden_states, kv_cache, attention_mask
)
feed_forward_hidden_states = self.mlp(hidden_states)
out = attn_outputs + feed_forward_hidden_states + residual
return out, past_kv_cache
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[
PhiBlock(f"transformer.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
hidden_states = self.embed_tokens(input_ids)
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
past_key_values = (
[None] * len(self.blocks) if past_key_values is None else past_key_values
)
for index, block in enumerate(self.blocks):
hidden_states, new_key_values = block(
hidden_states, past_key_values[index], mask
)
past_key_values[index] = new_key_values
return hidden_states, past_key_values
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = PhiModel(config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
model_output = self.model(
input_ids, past_key_values, attention_mask, return_dict, use_cache
)
logits = self.lm_head(model_output[0])
loss = None
if labels is not None:
loss = nn.CrossEntropyLoss()(
logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
)
if not return_dict:
return (
((loss,) + (logits,) + model_output[1:])
if loss is not None
else (logits,) + model_output[1:]
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=model_output[1],
hidden_states=None,
attentions=None,
)

View File

@ -842,6 +842,8 @@ class FlashCausalLM(Model):
else: else:
next_token_logits = out next_token_logits = out
speculate = get_speculate()
( (
next_input_ids, next_input_ids,
next_token_logprobs, next_token_logprobs,
@ -851,16 +853,15 @@ class FlashCausalLM(Model):
) = batch.next_token_chooser( ) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], batch.all_input_ids_tensor[:, : batch.max_seqlen],
next_token_logits, next_token_logits,
get_speculate(), speculate,
batch.speculative_ids, batch.speculative_ids,
speculative_logits, speculative_logits,
) )
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
) )
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill: if prefill:
if len(batch) > 1 and prefill_logprobs: if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -1062,20 +1063,24 @@ class FlashCausalLM(Model):
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0: if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode( all_top_tokens = []
top_token_ids, for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
clean_up_tokenization_spaces=False, toptoken_texts = self.tokenizer.batch_decode(
skip_special_tokens=False, top_token_ids,
) clean_up_tokenization_spaces=False,
special_toptokens = [ skip_special_tokens=False,
token_id in self.all_special_ids for token_id in top_token_ids )
] special_toptokens = [
top_tokens = Tokens( token_id in self.all_special_ids for token_id in top_token_ids
top_token_ids, ]
top_token_logprobs, top_tokens = Tokens(
toptoken_texts, top_token_ids,
special_toptokens, top_token_logprobs,
) toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else: else:
top_tokens = None top_tokens = None

View File

@ -74,9 +74,9 @@ class FlashLlama(FlashCausalLM):
import os import os
from pathlib import Path from pathlib import Path
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv( is_local_model = (
"WEIGHTS_CACHE_OVERRIDE", None Path(use_medusa).exists() and Path(use_medusa).is_dir()
) is not None ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model: if not is_local_model:
medusa_config = hf_hub_download( medusa_config = hf_hub_download(

View File

@ -0,0 +1,102 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
PhiConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashPhi(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashPhi is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -0,0 +1,66 @@
import torch
import torch.distributed
from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class Phi(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, _rank, _world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
model = PhiForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)

View File

@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
batch.past_key_values, batch.past_key_values,
) )
# Speculation is not active for seq2seq
accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens,
batch.top_n_tokens_tensor, batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1), torch.log_softmax(logits[:, -1], -1),
accepted_ids,
) )
start_decode = time.time_ns() start_decode = time.time_ns()
@ -746,20 +749,24 @@ class Seq2SeqLM(Model):
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0: if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode( all_top_tokens = []
top_token_ids, for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
clean_up_tokenization_spaces=False, toptoken_texts = self.tokenizer.batch_decode(
skip_special_tokens=False, top_token_ids,
) clean_up_tokenization_spaces=False,
special_toptokens = [ skip_special_tokens=False,
token_id in self.all_special_ids for token_id in top_token_ids )
] special_toptokens = [
top_tokens = Tokens( token_id in self.all_special_ids for token_id in top_token_ids
top_token_ids, ]
top_token_logprobs, top_tokens = Tokens(
toptoken_texts, top_token_ids,
special_toptokens, top_token_logprobs,
) toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else: else:
top_tokens = None top_tokens = None

View File

@ -95,5 +95,5 @@ class Generation:
generated_text=self.generated_text.to_pb() generated_text=self.generated_text.to_pb()
if self.generated_text is not None if self.generated_text is not None
else None, else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None, top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
) )

View File

@ -1,12 +1,9 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from logging import getLogger
import torch import torch
import torch.nn as nn import torch.nn as nn
import math
logger = getLogger(__name__) from loguru import logger
try: try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half from exllamav2_kernels import make_q_matrix, gemm_half_q_half
@ -185,6 +182,10 @@ class QuantLinear(nn.Module):
"g_idx": self.g_idx, "g_idx": self.g_idx,
} }
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
def forward(self, x, force_cuda=False): def forward(self, x, force_cuda=False):

View File

@ -33,14 +33,14 @@ except Exception:
major = 1 major = 1
HAS_EXLLAMA = False HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: # if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
V2 = False # V2 = False
log_once( # log_once(
logger.warning, # logger.warning,
"Disabling exllama v2 and using v1 instead because there are issues when sharding", # "Disabling exllama v2 and using v1 instead because there are issues when sharding",
) # )
if os.getenv("DISABLE_EXLLAMA") == "True": if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False HAS_EXLLAMA = False
@ -507,10 +507,12 @@ class TensorParallelEmbedding(nn.Module):
world_size = process_group.size() world_size = process_group.size()
rank = process_group.rank() rank = process_group.rank()
block_size = num_embeddings // world_size block_size = (num_embeddings + world_size - 1) // world_size
self.min_id = rank * block_size self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size) self.max_id = min(num_embeddings, (rank + 1) * block_size)
self.null_idx = block_size self.null_idx = weight.shape[
0
] # Usually block_size, might be less in non even vocab_size.
self.process_group = weights.process_group self.process_group = weights.process_group
self.reduce = reduce self.reduce = reduce

View File

@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
next_ids = next_ids.view(B * S) next_ids = next_ids.view(B * S)
scores = scores.view(B * S, -1) allscores = scores.view(B * S, -1)
alllogprobs = torch.log_softmax(allscores, -1)
if speculated_ids is not None: if speculated_ids is not None:
accepted_ids = [] accepted_ids = []
@ -305,16 +306,17 @@ class HeterogeneousNextTokenChooser:
accepted_ids, device=input_ids.device, dtype=input_ids.dtype accepted_ids, device=input_ids.device, dtype=input_ids.dtype
) )
next_ids = next_ids[indices] next_ids = next_ids[indices]
scores = scores[indices] logprobs = alllogprobs[indices]
indices = torch.arange(B, device=input_ids.device) * S indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None: if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1] speculative_scores = speculative_scores[indices + accepted_ids - 1]
else: else:
accepted_ids = torch.ones_like(next_ids) accepted_ids = torch.ones_like(next_ids)
logprobs = alllogprobs
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0: if speculate > 0:
if speculative_scores is not None: if speculative_scores is not None:
# Medusa provided some scores # Medusa provided some scores
@ -327,7 +329,7 @@ class HeterogeneousNextTokenChooser:
else: else:
speculative_ids = None speculative_ids = None
return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
def filter(self, indices): def filter(self, indices):
if self.watermark_processor is not None: if self.watermark_processor is not None:
@ -436,8 +438,8 @@ class HeterogeneousSampling:
def batch_top_tokens( def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]: ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations. """Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the When multiple tokens have equal probabilities and they don't all fit, the
@ -446,14 +448,19 @@ def batch_top_tokens(
max_top_n = max(top_n_tokens) max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used # Early exit when top_n_tokens is not used
if max_top_n == 0: if max_top_n == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens) return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
# Ensure top_n doesn't exceed vocab size # Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens] top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset # Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
nth_highest = torch.gather( nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1) sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
) )
@ -471,13 +478,33 @@ def batch_top_tokens(
top_indices = top_k.indices.tolist() top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist() top_values = top_k.values.tolist()
return ( batch_top_token_ids = []
[ batch_top_token_logprobs = []
idxs[:n] if req_n > 0 else [] accepted_ids_list = accepted_ids.tolist()
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens) for i, n_accepted_ids in enumerate(accepted_ids_list):
], start = speculate_size * i
[ stop = speculate_size * (i + 1)
vals[:n] if req_n > 0 else [] _top_indices = top_indices[start: stop]
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens) _top_values = top_values[start: stop]
], _top_n_ishes = top_n_ishes[start: stop]
) _top_n_tokens = top_n_tokens[start: stop]
_top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids]
_top_n_ishes = _top_n_ishes[:n_accepted_ids]
_top_n_tokens = _top_n_tokens[:n_accepted_ids]
row_top_token_ids = []
row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else []
row_top_token_ids.append(indices)
row_top_token_logprobs.append(values)
batch_top_token_ids.append(row_top_token_ids)
batch_top_token_logprobs.append(row_top_token_logprobs)
return batch_top_token_ids, batch_top_token_logprobs

View File

@ -92,7 +92,7 @@ class Weights:
rank = self.process_group.rank() rank = self.process_group.rank()
size = slice_.get_shape()[dim] size = slice_.get_shape()[dim]
block_size = size // world_size block_size = (size + world_size - 1) // world_size
start = rank * block_size start = rank * block_size
stop = (rank + 1) * block_size stop = (rank + 1) * block_size