Merge branch 'main' into impl-simple-mamba-model

This commit is contained in:
drbh 2024-02-06 18:38:20 -05:00 committed by GitHub
commit 9146ba00a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 3876 additions and 1197 deletions

View File

@ -1,12 +0,0 @@
name: Delete doc comment
on:
pull_request:
types: [ closed ]
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
with:
pr_number: ${{ github.event.number }}

10
.gitignore vendored
View File

@ -2,3 +2,13 @@
target
router/tokenizer.json
*__pycache__*
# ROCm auto-generated files
*.hip
server/exllamav2_kernels/exllamav2_kernels/hip/
server/exllama_kernels/exllama_kernels/hip/
server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

420
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,7 @@ members = [
resolver = "2"
[workspace.package]
version = "1.3.4"
version = "1.4.0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -75,8 +75,8 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \
rm ~/mambaforge.sh
# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
FROM base AS kernel-builder
@ -104,6 +104,20 @@ WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
# Build exllama kernels
FROM kernel-builder as exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
# Build exllama v2 kernels
FROM kernel-builder as exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
FROM base as base-copy
# Text Generation Inference base env
@ -120,6 +134,12 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

View File

@ -28,7 +28,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
- [Local Install](#local-install)
- [CUDA Kernels](#cuda-kernels)
- [Optimized architectures](#optimized-architectures)
- [Run Falcon](#run-falcon)
- [Run Mistral](#run-a-model)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
@ -42,7 +42,11 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
- Token streaming using Server-Sent Events (SSE)
- Continuous batching of incoming requests for increased total throughput
- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
- Quantization with :
- [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [GPT-Q](https://arxiv.org/abs/2210.17323)
- [EETQ](https://github.com/NetEase-FuXi/EETQ)
- [AWQ](https://github.com/casper-hansen/AutoAWQ)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@ -51,6 +55,14 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance
### Hardware support
- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference)
- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm)
- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)
- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)
- [Gaudi](https://github.com/huggingface/tgi-gaudi)
## Get Started
@ -62,7 +74,7 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
```
And then you can make requests like
@ -76,7 +88,7 @@ curl 127.0.0.1:8080/generate \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -106,7 +118,7 @@ model=meta-llama/Llama-2-7b-chat-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
```
### A note on Shared Memory (shm)
@ -154,7 +166,7 @@ Python 3.9, e.g. using `conda`:
```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
conda create -n text-generation-inference python=3.9
conda create -n text-generation-inference python=3.11
conda activate text-generation-inference
```
@ -180,7 +192,7 @@ Then run:
```shell
BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
make run-falcon-7b-instruct
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
```
**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
@ -189,16 +201,9 @@ make run-falcon-7b-instruct
sudo apt-get install libssl-dev gcc -y
```
### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100, AMD MI210 and AMD MI250. If you have any installation or runtime issues, you can remove
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
Be aware that the official Docker image has them enabled by default.
## Optimized architectures
TGI works out of the box to serve optimized models in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
Other architectures are supported on a best-effort basis using:
@ -210,12 +215,12 @@ or
## Run Falcon
## Run locally
### Run
```shell
make run-falcon-7b-instruct
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
```
### Quantization
@ -223,7 +228,7 @@ make run-falcon-7b-instruct
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
```shell
make run-falcon-7b-instruct-quantize
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
```
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "1.3.4"
"version": "1.4.0"
},
"paths": {
"/": {
@ -342,6 +342,135 @@
}
}
}
},
"/tokenize": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Tokenize inputs",
"description": "Tokenize inputs",
"operationId": "tokenize",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Tokenized ids",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/TokenizeResponse"
}
}
}
},
"404": {
"description": "No tokenizer found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "No fast tokenizer available"
}
}
}
}
}
}
},
"/v1/chat/completions": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "chat_completions",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ChatRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Text",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ChatCompletionChunk"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
}
}
}
}
}
}
}
},
"components": {
@ -399,6 +528,226 @@
}
}
},
"ChatCompletion": {
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
"choices",
"usage"
],
"properties": {
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ChatCompletionComplete"
}
},
"created": {
"type": "integer",
"format": "int64",
"example": "1706270835",
"minimum": 0
},
"id": {
"type": "string"
},
"model": {
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
},
"usage": {
"$ref": "#/components/schemas/Usage"
}
}
},
"ChatCompletionChoice": {
"type": "object",
"required": [
"index",
"delta"
],
"properties": {
"delta": {
"$ref": "#/components/schemas/ChatCompletionDelta"
},
"finish_reason": {
"type": "string",
"nullable": true
},
"index": {
"type": "integer",
"format": "int32",
"minimum": 0
},
"logprobs": {
"type": "number",
"format": "float",
"nullable": true
}
}
},
"ChatCompletionChunk": {
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
"choices"
],
"properties": {
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ChatCompletionChoice"
}
},
"created": {
"type": "integer",
"format": "int64",
"example": "1706270978",
"minimum": 0
},
"id": {
"type": "string"
},
"model": {
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
}
}
},
"ChatCompletionDelta": {
"type": "object",
"required": [
"role",
"content"
],
"properties": {
"content": {
"type": "string",
"example": "What is Deep Learning?"
},
"role": {
"type": "string",
"example": "user"
}
}
},
"ChatRequest": {
"type": "object",
"required": [
"model"
],
"properties": {
"frequency_penalty": {
"type": "number",
"format": "float",
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
"example": "1.0",
"nullable": true
},
"logit_bias": {
"type": "array",
"items": {
"type": "number",
"format": "float"
},
"description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
"nullable": true
},
"logprobs": {
"type": "boolean",
"description": "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\noutput token returned in the content of message.",
"example": "false",
"nullable": true
},
"max_tokens": {
"type": "integer",
"format": "int32",
"description": "The maximum number of tokens that can be generated in the chat completion.",
"example": "32",
"nullable": true,
"minimum": 0
},
"messages": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
},
"description": "A list of messages comprising the conversation so far."
},
"model": {
"type": "string",
"description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"n": {
"type": "integer",
"format": "int32",
"description": "UNUSED\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.",
"example": "2",
"nullable": true,
"minimum": 0
},
"presence_penalty": {
"type": "number",
"format": "float",
"description": "UNUSED\nNumber between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics",
"example": 0.1,
"nullable": true
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true,
"minimum": 0
},
"stream": {
"type": "boolean"
},
"temperature": {
"type": "number",
"format": "float",
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.",
"example": 1.0,
"nullable": true
},
"top_logprobs": {
"type": "integer",
"format": "int32",
"description": "UNUSED\nAn integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.",
"example": "5",
"nullable": true,
"minimum": 0
},
"top_p": {
"type": "number",
"format": "float",
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95,
"nullable": true
}
}
},
"CompatGenerateRequest": {
"type": "object",
"required": [
@ -494,7 +843,8 @@
"length",
"eos_token",
"stop_sequence"
]
],
"example": "Length"
},
"GenerateParameters": {
"type": "object",
@ -523,7 +873,7 @@
"max_new_tokens": {
"type": "integer",
"format": "int32",
"default": "20",
"default": "100",
"example": "20",
"nullable": true,
"minimum": 0
@ -758,6 +1108,23 @@
}
}
},
"Message": {
"type": "object",
"required": [
"role",
"content"
],
"properties": {
"content": {
"type": "string",
"example": "My name is David and I"
},
"role": {
"type": "string",
"example": "user"
}
}
},
"PrefillToken": {
"type": "object",
"required": [
@ -784,6 +1151,37 @@
}
}
},
"SimpleToken": {
"type": "object",
"required": [
"id",
"text",
"start",
"stop"
],
"properties": {
"id": {
"type": "integer",
"format": "int32",
"example": 0,
"minimum": 0
},
"start": {
"type": "integer",
"example": 0,
"minimum": 0
},
"stop": {
"type": "integer",
"example": 2,
"minimum": 0
},
"text": {
"type": "string",
"example": "test"
}
}
},
"StreamDetails": {
"type": "object",
"required": [
@ -812,6 +1210,7 @@
"StreamResponse": {
"type": "object",
"required": [
"index",
"token"
],
"properties": {
@ -830,6 +1229,11 @@
"example": "test",
"nullable": true
},
"index": {
"type": "integer",
"format": "int32",
"minimum": 0
},
"token": {
"$ref": "#/components/schemas/Token"
},
@ -871,6 +1275,12 @@
"example": "test"
}
}
},
"TokenizeResponse": {
"type": "array",
"items": {
"$ref": "#/components/schemas/SimpleToken"
}
}
}
},

View File

@ -7,6 +7,8 @@
title: Installation
- local: supported_models
title: Supported Models and Hardware
- local: messages_api
title: Messages API
title: Getting started
- sections:
- local: basic_tutorials/consuming_tgi

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HUGGING_FACE_HUB_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
--model-id $model
```

View File

@ -60,9 +60,9 @@ Options:
[env: QUANTIZE=]
Possible values:
- awq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models whereever possible because of the better latency
- awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels whereever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
@ -354,6 +354,14 @@ Options:
[env: NGROK_EDGE=]
```
## TOKENIZER_CONFIG_PATH
```shell
--tokenizer-config-path <TOKENIZER_CONFIG_PATH>
The path to the tokenizer config file. This path is used to load the tokenizer configuration which may include a `chat_template`. If not provided, the default config will be used from the model hub
[env: TOKENIZER_CONFIG_PATH=]
```
## ENV
```shell

View File

@ -1,6 +1,6 @@
# Using TGI CLI
You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](./installation#install-cli).
You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli).
`text-generation-server` lets you download the model with `download-weights` command like below 👇

175
docs/source/messages_api.md Normal file
View File

@ -0,0 +1,175 @@
# Messages API
Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.
#### Table of Contents
- [Making a Request](#making-a-request)
- [Streaming](#streaming)
- [Synchronous](#synchronous)
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
- [Cloud Providers](#cloud-providers)
- [Amazon SageMaker](#amazon-sagemaker)
## Making a Request
You can make a request to TGI's Messages API using `curl`. Here's an example:
```bash
curl localhost:3000/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is deep learning?"
}
],
"stream": true,
"max_tokens": 20
}' \
-H 'Content-Type: application/json'
```
## Streaming
You can also use OpenAI's Python client library to make a streaming request. Here's how:
```python
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(
base_url="http://localhost:3000/v1",
api_key="-"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=True
)
# iterate and print stream
for message in chat_completion:
print(message)
```
## Synchronous
If you prefer to make a synchronous request, you can do so like this:
```python
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(
base_url="http://localhost:3000/v1",
api_key="-"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=False
)
print(chat_completion)
```
## Hugging Face Inference Endpoints
The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated).
Every endpoint that uses "Text Generation Inference" with an LLM, which has a chat template can now be used. Below is an example of how to use IE with TGI using OpenAI's Python client library:
> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key.
```python
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(
# replace with your endpoint url, make sure to include "v1/" at the end
base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/",
# replace with your API key
api_key="hf_XXX"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=True
)
# iterate and print stream
for message in chat_completion:
print(message.choices[0].delta.content, end="")
```
## Cloud Providers
TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:
## Amazon SageMaker
To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.
```python
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
try:
role = sagemaker.get_execution_role()
except ValueError:
iam = boto3.client('iam')
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
# Hub Model configuration. https://huggingface.co/models
hub = {
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
'SM_NUM_GPUS': json.dumps(1),
'MESSAGES_API_ENABLED': True
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
env=hub,
role=role,
)
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
initial_instance_count=1,
instance_type="ml.g5.2xlarge",
container_startup_health_check_timeout=300,
)
# send request
predictor.predict({
"messages": [
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
]
})
```

View File

@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
```
<Tip warning={true}>
@ -20,7 +20,7 @@ To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://d
TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:
```bash
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
```
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
@ -91,7 +91,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
docker run ghcr.io/huggingface/text-generation-inference:1.3 --help
docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
```
</Tip>

View File

@ -19,7 +19,9 @@ The following models are optimized and can be served with TGI, which uses custom
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama)
- [Code Llama](https://huggingface.co/codellama)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Phi](https://huggingface.co/microsoft/phi-2)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
@ -41,8 +43,8 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
* Quantization (GPTQ, AWQ, etc.)
TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
* Kernel for slinding window attention (Mistral)

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.76660156,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7246094,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.41333008,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.11785889,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.97265625,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.0569458,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
}

View File

@ -0,0 +1,60 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 6,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 284,
"logprob": -0.19421387,
"special": false,
"text": " to"
},
{
"id": 3758,
"logprob": -0.62597656,
"special": false,
"text": " send"
},
{
"id": 1366,
"logprob": -0.87060547,
"special": false,
"text": " data"
},
{
"id": 625,
"logprob": -0.88427734,
"special": false,
"text": " over"
},
{
"id": 257,
"logprob": -1.0830078,
"special": false,
"text": " a"
},
{
"id": 3127,
"logprob": -1.9462891,
"special": false,
"text": " network"
}
],
"top_tokens": null
},
"generated_text": "Test request to send data over a network"
}

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
}
]

View File

@ -16,52 +16,52 @@
},
{
"id": 21017,
"logprob": -9.09375,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25976562,
"logprob": -0.25830078,
"text": "_"
},
{
"id": 6009,
"logprob": -2.2148438,
"logprob": -2.1875,
"text": "mean"
},
{
"id": 26,
"logprob": -0.3010254,
"logprob": -0.30004883,
"text": "("
},
{
"id": 62,
"logprob": -5.6757812,
"logprob": -5.6171875,
"text": "L"
},
{
"id": 44,
"logprob": -3.0898438,
"logprob": -3.078125,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6791992,
"logprob": -0.68066406,
"text": " List"
},
{
"id": 77,
"logprob": -0.38891602,
"logprob": -0.38745117,
"text": "["
},
{
"id": 1808,
"logprob": -0.92041016,
"logprob": -0.9453125,
"text": "float"
},
{
"id": 10794,
"logprob": -2.5390625,
"logprob": -2.5371094,
"text": "]):"
}
],
@ -69,7 +69,7 @@
"tokens": [
{
"id": 284,
"logprob": 0.0,
"logprob": -0.051635742,
"special": false,
"text": "\n "
},
@ -81,7 +81,7 @@
},
{
"id": 11665,
"logprob": -1.6005859,
"logprob": -1.2236328,
"special": false,
"text": " reduce"
},
@ -159,7 +159,7 @@
},
{
"id": 203,
"logprob": -0.11968994,
"logprob": -0.12695312,
"special": false,
"text": "\n"
},

View File

@ -11,92 +11,92 @@
},
{
"id": 4911,
"logprob": -5.7851562,
"logprob": -6.9765625,
"text": "User"
},
{
"id": 29901,
"logprob": -0.006996155,
"logprob": -0.0059432983,
"text": ":"
},
{
"id": 32000,
"logprob": -0.81347656,
"logprob": -0.8408203,
"text": "<fake_token_around_image>"
},
{
"id": 32001,
"logprob": -6.687641e-05,
"logprob": -9.906292e-05,
"text": "<image>"
},
{
"id": 32000,
"logprob": -3.5762787e-07,
"logprob": -2.3841858e-07,
"text": "<fake_token_around_image>"
},
{
"id": 1815,
"logprob": -4.2148438,
"logprob": -4.1679688,
"text": "Can"
},
{
"id": 366,
"logprob": -0.014137268,
"logprob": -0.014099121,
"text": "you"
},
{
"id": 2649,
"logprob": -4.4335938,
"logprob": -4.4609375,
"text": "tell"
},
{
"id": 592,
"logprob": -0.2919922,
"logprob": -0.29882812,
"text": "me"
},
{
"id": 263,
"logprob": -4.2070312,
"logprob": -4.1445312,
"text": "a"
},
{
"id": 1407,
"logprob": -9.421875,
"logprob": -9.3828125,
"text": "very"
},
{
"id": 3273,
"logprob": -1.8720703,
"logprob": -1.9736328,
"text": "short"
},
{
"id": 5828,
"logprob": -0.26489258,
"logprob": -0.2800293,
"text": "story"
},
{
"id": 2729,
"logprob": -3.7441406,
"logprob": -3.5625,
"text": "based"
},
{
"id": 373,
"logprob": -0.0005393028,
"logprob": -0.0006427765,
"text": "on"
},
{
"id": 278,
"logprob": -0.140625,
"logprob": -0.13952637,
"text": "the"
},
{
"id": 1967,
"logprob": -0.06756592,
"logprob": -0.068115234,
"text": "image"
},
{
"id": 29973,
"logprob": -0.15454102,
"logprob": -0.16357422,
"text": "?"
}
],
@ -104,25 +104,25 @@
"tokens": [
{
"id": 32002,
"logprob": -0.0019140244,
"logprob": -0.0026474,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 29871,
"logprob": -8.404255e-05,
"logprob": -8.547306e-05,
"special": false,
"text": " "
},
{
"id": 13,
"logprob": -1.7642975e-05,
"logprob": -1.7881393e-05,
"special": false,
"text": "\n"
},
{
"id": 7900,
"logprob": -2.9802322e-06,
"logprob": -3.0994415e-06,
"special": false,
"text": "Ass"
},
@ -140,30 +140,29 @@
},
{
"id": 319,
"logprob": -0.91064453,
"logprob": -0.92529297,
"special": false,
"text": " A"
},
{
"id": 696,
"logprob": -1.2412109,
"logprob": -1.1269531,
"special": false,
"text": " ro"
},
{
"id": 15664,
"logprob": -0.0002439022,
"logprob": -0.00029492378,
"special": false,
"text": "oster"
},
{
"id": 15028,
"logprob": -1.1630859,
"logprob": -1.1855469,
"special": false,
"text": " stands"
}
],
"top_tokens": null
]
},
"generated_text": " \nAssistant: A rooster stands"
}

View File

@ -12,92 +12,92 @@
},
{
"id": 4911,
"logprob": -5.7851562,
"logprob": -6.9804688,
"text": "User"
},
{
"id": 29901,
"logprob": -0.006996155,
"logprob": -0.006122589,
"text": ":"
},
{
"id": 32000,
"logprob": -0.81347656,
"logprob": -0.8417969,
"text": "<fake_token_around_image>"
},
{
"id": 32001,
"logprob": -6.687641e-05,
"logprob": -9.918213e-05,
"text": "<image>"
},
{
"id": 32000,
"logprob": -3.5762787e-07,
"logprob": -2.3841858e-07,
"text": "<fake_token_around_image>"
},
{
"id": 1815,
"logprob": -4.2148438,
"logprob": -4.1679688,
"text": "Can"
},
{
"id": 366,
"logprob": -0.014137268,
"logprob": -0.014091492,
"text": "you"
},
{
"id": 2649,
"logprob": -4.4335938,
"logprob": -4.4726562,
"text": "tell"
},
{
"id": 592,
"logprob": -0.2919922,
"logprob": -0.2998047,
"text": "me"
},
{
"id": 263,
"logprob": -4.2070312,
"logprob": -4.15625,
"text": "a"
},
{
"id": 1407,
"logprob": -9.421875,
"logprob": -9.3828125,
"text": "very"
},
{
"id": 3273,
"logprob": -1.8720703,
"logprob": -1.9716797,
"text": "short"
},
{
"id": 5828,
"logprob": -0.26489258,
"logprob": -0.27734375,
"text": "story"
},
{
"id": 2729,
"logprob": -3.7441406,
"logprob": -3.5605469,
"text": "based"
},
{
"id": 373,
"logprob": -0.0005393028,
"logprob": -0.00064468384,
"text": "on"
},
{
"id": 278,
"logprob": -0.140625,
"logprob": -0.14160156,
"text": "the"
},
{
"id": 1967,
"logprob": -0.06756592,
"logprob": -0.06915283,
"text": "image"
},
{
"id": 29973,
"logprob": -0.15454102,
"logprob": -0.16381836,
"text": "?"
}
],
@ -105,19 +105,19 @@
"tokens": [
{
"id": 32002,
"logprob": -0.0019140244,
"logprob": -0.0026664734,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 29871,
"logprob": -8.392334e-05,
"logprob": -8.583069e-05,
"special": false,
"text": " "
},
{
"id": 13,
"logprob": -1.7881393e-05,
"logprob": -1.8119812e-05,
"special": false,
"text": "\n"
},
@ -135,36 +135,35 @@
},
{
"id": 29901,
"logprob": -3.0994415e-06,
"logprob": -3.2186508e-06,
"special": false,
"text": ":"
},
{
"id": 319,
"logprob": -0.9057617,
"logprob": -0.9301758,
"special": false,
"text": " A"
},
{
"id": 696,
"logprob": -1.2294922,
"logprob": -1.1279297,
"special": false,
"text": " ro"
},
{
"id": 15664,
"logprob": -0.00024533272,
"logprob": -0.0002939701,
"special": false,
"text": "oster"
},
{
"id": 15028,
"logprob": -1.1640625,
"logprob": -1.1865234,
"special": false,
"text": " stands"
}
],
"top_tokens": null
]
},
"generated_text": " \nAssistant: A rooster stands"
},
@ -181,92 +180,92 @@
},
{
"id": 4911,
"logprob": -5.7773438,
"logprob": -6.9804688,
"text": "User"
},
{
"id": 29901,
"logprob": -0.0070114136,
"logprob": -0.006122589,
"text": ":"
},
{
"id": 32000,
"logprob": -0.8208008,
"logprob": -0.8417969,
"text": "<fake_token_around_image>"
},
{
"id": 32001,
"logprob": -6.699562e-05,
"logprob": -9.942055e-05,
"text": "<image>"
},
{
"id": 32000,
"logprob": -3.5762787e-07,
"logprob": -2.3841858e-07,
"text": "<fake_token_around_image>"
},
{
"id": 1815,
"logprob": -4.2265625,
"logprob": -4.1679688,
"text": "Can"
},
{
"id": 366,
"logprob": -0.014175415,
"logprob": -0.014091492,
"text": "you"
},
{
"id": 2649,
"logprob": -4.4296875,
"logprob": -4.4726562,
"text": "tell"
},
{
"id": 592,
"logprob": -0.29516602,
"logprob": -0.2998047,
"text": "me"
},
{
"id": 263,
"logprob": -4.2109375,
"logprob": -4.15625,
"text": "a"
},
{
"id": 1407,
"logprob": -9.4296875,
"logprob": -9.3828125,
"text": "very"
},
{
"id": 3273,
"logprob": -1.8720703,
"logprob": -1.9716797,
"text": "short"
},
{
"id": 5828,
"logprob": -0.26879883,
"logprob": -0.27734375,
"text": "story"
},
{
"id": 2729,
"logprob": -3.7675781,
"logprob": -3.5605469,
"text": "based"
},
{
"id": 373,
"logprob": -0.0005354881,
"logprob": -0.0006451607,
"text": "on"
},
{
"id": 278,
"logprob": -0.13671875,
"logprob": -0.14160156,
"text": "the"
},
{
"id": 1967,
"logprob": -0.06719971,
"logprob": -0.06915283,
"text": "image"
},
{
"id": 29973,
"logprob": -0.15551758,
"logprob": -0.16381836,
"text": "?"
}
],
@ -274,19 +273,19 @@
"tokens": [
{
"id": 32002,
"logprob": -0.0019130707,
"logprob": -0.0026664734,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 29871,
"logprob": -8.392334e-05,
"logprob": -8.571148e-05,
"special": false,
"text": " "
},
{
"id": 13,
"logprob": -1.7881393e-05,
"logprob": -1.8119812e-05,
"special": false,
"text": "\n"
},
@ -310,30 +309,29 @@
},
{
"id": 319,
"logprob": -0.9013672,
"logprob": -0.9301758,
"special": false,
"text": " A"
},
{
"id": 696,
"logprob": -1.2324219,
"logprob": -1.1279297,
"special": false,
"text": " ro"
},
{
"id": 15664,
"logprob": -0.0002477169,
"logprob": -0.0002939701,
"special": false,
"text": "oster"
},
{
"id": 15028,
"logprob": -1.1660156,
"logprob": -1.1865234,
"special": false,
"text": " stands"
}
],
"top_tokens": null
]
},
"generated_text": " \nAssistant: A rooster stands"
},
@ -350,92 +348,92 @@
},
{
"id": 4911,
"logprob": -5.7773438,
"logprob": -6.9804688,
"text": "User"
},
{
"id": 29901,
"logprob": -0.0070114136,
"logprob": -0.006122589,
"text": ":"
},
{
"id": 32000,
"logprob": -0.8208008,
"logprob": -0.8417969,
"text": "<fake_token_around_image>"
},
{
"id": 32001,
"logprob": -6.699562e-05,
"logprob": -9.918213e-05,
"text": "<image>"
},
{
"id": 32000,
"logprob": -3.5762787e-07,
"logprob": -2.3841858e-07,
"text": "<fake_token_around_image>"
},
{
"id": 1815,
"logprob": -4.2265625,
"logprob": -4.1679688,
"text": "Can"
},
{
"id": 366,
"logprob": -0.014175415,
"logprob": -0.014091492,
"text": "you"
},
{
"id": 2649,
"logprob": -4.4296875,
"logprob": -4.4726562,
"text": "tell"
},
{
"id": 592,
"logprob": -0.29516602,
"logprob": -0.2998047,
"text": "me"
},
{
"id": 263,
"logprob": -4.2109375,
"logprob": -4.15625,
"text": "a"
},
{
"id": 1407,
"logprob": -9.4296875,
"logprob": -9.3828125,
"text": "very"
},
{
"id": 3273,
"logprob": -1.8720703,
"logprob": -1.9716797,
"text": "short"
},
{
"id": 5828,
"logprob": -0.26879883,
"logprob": -0.27734375,
"text": "story"
},
{
"id": 2729,
"logprob": -3.7675781,
"logprob": -3.5605469,
"text": "based"
},
{
"id": 373,
"logprob": -0.0005354881,
"logprob": -0.00064468384,
"text": "on"
},
{
"id": 278,
"logprob": -0.13671875,
"logprob": -0.14160156,
"text": "the"
},
{
"id": 1967,
"logprob": -0.06719971,
"logprob": -0.06915283,
"text": "image"
},
{
"id": 29973,
"logprob": -0.15551758,
"logprob": -0.16381836,
"text": "?"
}
],
@ -443,19 +441,19 @@
"tokens": [
{
"id": 32002,
"logprob": -0.001912117,
"logprob": -0.0026664734,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 29871,
"logprob": -8.392334e-05,
"logprob": -8.59499e-05,
"special": false,
"text": " "
},
{
"id": 13,
"logprob": -1.7762184e-05,
"logprob": -1.8119812e-05,
"special": false,
"text": "\n"
},
@ -479,30 +477,29 @@
},
{
"id": 319,
"logprob": -0.9013672,
"logprob": -0.9301758,
"special": false,
"text": " A"
},
{
"id": 696,
"logprob": -1.2324219,
"logprob": -1.1279297,
"special": false,
"text": " ro"
},
{
"id": 15664,
"logprob": -0.0002477169,
"logprob": -0.0002939701,
"special": false,
"text": "oster"
},
{
"id": 15028,
"logprob": -1.1660156,
"logprob": -1.1865234,
"special": false,
"text": " stands"
}
],
"top_tokens": null
]
},
"generated_text": " \nAssistant: A rooster stands"
},
@ -519,92 +516,92 @@
},
{
"id": 4911,
"logprob": -5.7773438,
"logprob": -6.9804688,
"text": "User"
},
{
"id": 29901,
"logprob": -0.0070114136,
"logprob": -0.006122589,
"text": ":"
},
{
"id": 32000,
"logprob": -0.8208008,
"logprob": -0.8417969,
"text": "<fake_token_around_image>"
},
{
"id": 32001,
"logprob": -6.699562e-05,
"logprob": -9.942055e-05,
"text": "<image>"
},
{
"id": 32000,
"logprob": -3.5762787e-07,
"logprob": -2.3841858e-07,
"text": "<fake_token_around_image>"
},
{
"id": 1815,
"logprob": -4.2265625,
"logprob": -4.1679688,
"text": "Can"
},
{
"id": 366,
"logprob": -0.014175415,
"logprob": -0.014091492,
"text": "you"
},
{
"id": 2649,
"logprob": -4.4296875,
"logprob": -4.4726562,
"text": "tell"
},
{
"id": 592,
"logprob": -0.29516602,
"logprob": -0.2998047,
"text": "me"
},
{
"id": 263,
"logprob": -4.2109375,
"logprob": -4.15625,
"text": "a"
},
{
"id": 1407,
"logprob": -9.4296875,
"logprob": -9.3828125,
"text": "very"
},
{
"id": 3273,
"logprob": -1.8720703,
"logprob": -1.9716797,
"text": "short"
},
{
"id": 5828,
"logprob": -0.26879883,
"logprob": -0.27734375,
"text": "story"
},
{
"id": 2729,
"logprob": -3.7675781,
"logprob": -3.5605469,
"text": "based"
},
{
"id": 373,
"logprob": -0.0005354881,
"logprob": -0.0006451607,
"text": "on"
},
{
"id": 278,
"logprob": -0.13671875,
"logprob": -0.14160156,
"text": "the"
},
{
"id": 1967,
"logprob": -0.06719971,
"logprob": -0.06915283,
"text": "image"
},
{
"id": 29973,
"logprob": -0.15551758,
"logprob": -0.16381836,
"text": "?"
}
],
@ -612,19 +609,19 @@
"tokens": [
{
"id": 32002,
"logprob": -0.001912117,
"logprob": -0.0026664734,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 29871,
"logprob": -8.392334e-05,
"logprob": -8.571148e-05,
"special": false,
"text": " "
},
{
"id": 13,
"logprob": -1.7762184e-05,
"logprob": -1.8119812e-05,
"special": false,
"text": "\n"
},
@ -648,30 +645,29 @@
},
{
"id": 319,
"logprob": -0.9013672,
"logprob": -0.9301758,
"special": false,
"text": " A"
},
{
"id": 696,
"logprob": -1.2324219,
"logprob": -1.1279297,
"special": false,
"text": " ro"
},
{
"id": 15664,
"logprob": -0.0002477169,
"logprob": -0.0002939701,
"special": false,
"text": "oster"
},
{
"id": 15028,
"logprob": -1.1660156,
"logprob": -1.1865234,
"special": false,
"text": " stands"
}
],
"top_tokens": null
]
},
"generated_text": " \nAssistant: A rooster stands"
}

View File

@ -0,0 +1,63 @@
import pytest
@pytest.fixture(scope="module")
def flash_phi_handle(launcher):
with launcher("microsoft/phi-2", num_shard=1) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_phi(flash_phi_handle):
await flash_phi_handle.health(300)
return flash_phi_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response.generated_text == ': {request}")\n response = self'
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["network"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 6
assert response.generated_text == "Test request to send data over a network"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ': {request}")\n response = self'
assert responses == response_snapshot

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-integration-tests"
version = "1.3.4"
version = "1.4.0"
description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"]

View File

@ -4,7 +4,7 @@ use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::io::{BufRead, BufReader, Lines};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
@ -21,16 +21,16 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
/// 4 bit quantization. Requires a specific GTPQ quantized model:
/// 4 bit quantization. Requires a specific AWQ quantized model:
/// https://hf.co/models?search=awq.
/// Should replace GPTQ models whereever possible because of the better latency
/// Should replace GPTQ models wherever possible because of the better latency
Awq,
/// 8 bit quantization, doesn't require specific model.
/// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
Eetq,
/// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq.
/// text-generation-inference will use exllama (faster) kernels whereever possible, and use
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not.
/// AWQ has faster kernels.
Gptq,
@ -368,6 +368,11 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
/// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
/// include a `chat_template`. If not provided, the default config will be used from the model hub.
#[clap(long, env)]
tokenizer_config_path: Option<String>,
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
@ -489,6 +494,9 @@ fn shard_manager(
// Safetensors load fast
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
envs.push((
@ -573,6 +581,13 @@ fn shard_manager(
thread::spawn(move || {
log_lines(shard_stdout_reader.lines());
});
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut ready = false;
let start_time = Instant::now();
@ -580,13 +595,6 @@ fn shard_manager(
loop {
// Process exited
if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut err = String::new();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
@ -782,6 +790,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
@ -832,12 +843,20 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
let stdout = BufReader::new(download_stdout);
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || {
log_lines(stdout.lines());
log_lines(download_stdout.lines());
});
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in download_stderr.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
loop {
@ -848,12 +867,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
}
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
}
if let Some(signal) = status.signal() {
tracing::error!(
"Download process was signaled to shutdown with signal {signal}: {err}"
@ -965,7 +982,20 @@ fn spawn_shards(
Ok(())
}
fn compute_type(num_shard: usize) -> Option<String> {
let output = Command::new("nvidia-smi")
.args(["--query-gpu=gpu_name", "--format=csv"])
.output()
.ok()?;
let output = String::from_utf8(output.stdout).ok()?;
let fullname = output.split('\n').nth(1)?;
let cardname = fullname.replace(' ', "-").to_lowercase();
let compute_type = format!("{num_shard}-{cardname}");
Some(compute_type)
}
fn spawn_webserver(
num_shard: usize,
args: Args,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
@ -1004,6 +1034,12 @@ fn spawn_webserver(
args.model_id,
];
// Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string());
router_args.push(tokenizer_config_path.to_string());
}
// Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
router_args.push("--max-batch-total-tokens".to_string());
@ -1049,6 +1085,13 @@ fn spawn_webserver(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
// Parse Compute type
if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
} else if let Some(compute_type) = compute_type(num_shard) {
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
}
let mut webserver = match Command::new("text-generation-router")
.args(router_args)
.envs(envs)
@ -1242,8 +1285,8 @@ fn main() -> Result<(), LauncherError> {
return Ok(());
}
let mut webserver =
spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
.map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;

View File

@ -32,7 +32,7 @@ reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { version = "0.14.0", features = ["http"] }
tokenizers = { version = "0.15.1", features = ["http"] }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] }

View File

@ -165,6 +165,28 @@ impl Infer {
))
}
/// Tokenizer the input
#[instrument(skip_all)]
pub(crate) async fn tokenize(
&self,
request: GenerateRequest,
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.await
.map_err(|err| {
tracing::error!("Tokenization {err}");
err
})?;
// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
}
/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {

View File

@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
}
impl HubTokenizerConfig {
pub fn from_file(filename: &str) -> Self {
pub fn from_file(filename: &std::path::Path) -> Self {
let content = std::fs::read_to_string(filename).unwrap();
serde_json::from_str(&content).unwrap_or_default()
}
@ -188,18 +188,20 @@ fn default_parameters() -> GenerateParameters {
}
}
#[derive(Clone, Deserialize, Serialize)]
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
pub id: String,
pub object: String,
#[schema(example = "1706270835")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionComplete>,
pub usage: Usage,
}
#[derive(Clone, Deserialize, Serialize)]
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionComplete {
pub index: u32,
pub message: Message,
@ -248,17 +250,19 @@ impl ChatCompletion {
}
}
#[derive(Clone, Deserialize, Serialize)]
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,
pub object: String,
#[schema(example = "1706270978")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionChoice>,
}
#[derive(Clone, Deserialize, Serialize)]
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChoice {
pub index: u32,
pub delta: ChatCompletionDelta,
@ -266,9 +270,11 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
pub role: String,
#[schema(example = "What is Deep Learning?")]
pub content: String,
}
@ -311,7 +317,7 @@ fn default_request_messages() -> Vec<Message> {
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct ChatRequest {
/// UNUSED
#[schema(example = "bigscience/blomm-560m")]
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, /* NOTE: UNUSED */
@ -322,6 +328,7 @@ pub(crate) struct ChatRequest {
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// UNUSED
@ -336,28 +343,33 @@ pub(crate) struct ChatRequest {
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
#[serde(default)]
#[schema(example = "false")]
pub logprobs: Option<bool>,
/// UNUSED
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
#[schema(example = "5")]
pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(example = "32")]
pub max_tokens: Option<u32>,
/// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)]
#[schema(nullable = true, example = "2")]
pub n: Option<u32>,
/// UNUSED
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
#[schema(nullable = true, example = 0.1)]
pub presence_penalty: Option<f32>,
#[serde(default = "bool::default")]
@ -365,6 +377,20 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
}
#[derive(Clone, Serialize, Deserialize)]
@ -432,8 +458,21 @@ pub struct Token {
special: bool,
}
#[derive(Debug, Serialize, ToSchema)]
pub struct SimpleToken {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(example = 0)]
start: usize,
#[schema(example = 2)]
stop: usize,
}
#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason {
#[schema(rename = "length")]
Length,
@ -494,6 +533,10 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>,
}
#[derive(Serialize, ToSchema)]
#[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
#[schema(example = "length")]
@ -524,26 +567,12 @@ pub(crate) struct ErrorResponse {
#[cfg(test)]
mod tests {
use std::io::Write;
use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer {
let filename = std::path::Path::new("tokenizer.json");
if !filename.exists() {
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
.await
.unwrap()
.bytes()
.await
.unwrap();
let tmp_filename = "tokenizer.json.temp";
let mut file = std::fs::File::create(tmp_filename).unwrap();
file.write_all(&content).unwrap();
// Re-check if another process has written this file maybe.
if !filename.exists() {
std::fs::rename(tmp_filename, filename).unwrap()
}
}
Tokenizer::from_file("tokenizer.json").unwrap()
let api = hf_hub::api::sync::Api::new().unwrap();
let repo = api.model("gpt2".to_string());
let filename = repo.get("tokenizer.json").unwrap();
Tokenizer::from_file(filename).unwrap()
}
}

View File

@ -72,7 +72,7 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
chat_enabled_api: bool,
messages_api_enabled: bool,
}
#[tokio::main]
@ -104,7 +104,7 @@ async fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
chat_enabled_api,
messages_api_enabled,
} = args;
// Launch Tokio runtime
@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> {
let local_path = Path::new(&tokenizer_name);
let local_model = local_path.exists() && local_path.is_dir();
// Load tokenizer config
// This will be used to format the chat template
let local_tokenizer_config_path =
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> {
};
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if local_tokenizer_config {
let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!("Using local tokenizer config from user specified path");
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
} else if local_model {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
} else if let Some(api) = api {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
get_tokenizer_config(&api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main".to_string()),
)))
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
HubTokenizerConfig::default()
})
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
} else {
tracing::warn!("Could not find tokenizer config locally and no revision specified");
HubTokenizerConfig::default()
match api {
Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
let repo = Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or("main".to_string()),
);
get_tokenizer_config(&api.repo(repo))
.await
.unwrap_or_else(|| {
tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
}
}
};
if tokenizer.is_none() {
@ -348,7 +353,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
tokenizer_config,
chat_enabled_api,
messages_api_enabled,
)
.await?;
Ok(())
@ -462,7 +467,12 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader).ok()?;
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}

View File

@ -3,10 +3,10 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
Token, Validation,
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
async fn compat_generate(
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag
@ -66,11 +67,11 @@ async fn compat_generate(
// switch on stream
if req.stream {
Ok(generate_stream(infer, Json(req.into()))
Ok(generate_stream(infer, compute_type, Json(req.into()))
.await
.into_response())
} else {
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response())
}
@ -145,6 +146,7 @@ seed,
)]
async fn generate(
infer: Extension<Infer>,
Extension(ComputeType(compute_type)): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
@ -230,7 +232,7 @@ async fn generate(
// Headers
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert(
"x-compute-time",
total_time.as_millis().to_string().parse().unwrap(),
@ -339,6 +341,7 @@ seed,
)]
async fn generate_stream(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<GenerateRequest>,
) -> (
HeaderMap,
@ -349,13 +352,14 @@ async fn generate_stream(
event.json_data(stream_token).unwrap()
};
let (headers, response_stream) =
generate_stream_internal(infer, Json(req), on_message_callback).await;
generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse)
}
async fn generate_stream_internal(
infer: Infer,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
@ -368,7 +372,7 @@ async fn generate_stream_internal(
let compute_characters = req.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
@ -532,7 +536,7 @@ async fn generate_stream_internal(
path = "/v1/chat/completions",
request_body = ChatRequest,
responses(
(status = 200, description = "Generated Text", body = GenerateResponse),
(status = 200, description = "Generated Text", body = ChatCompletionChunk),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
@ -557,6 +561,7 @@ async fn generate_stream_internal(
)]
async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@ -592,10 +597,10 @@ async fn chat_completions(
inputs: inputs.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
temperature: req.temperature,
repetition_penalty,
top_k: None,
top_p: None,
top_p: req.top_p,
typical_p: None,
do_sample: true,
max_new_tokens,
@ -604,7 +609,7 @@ async fn chat_completions(
truncate: None,
watermark: false,
details: true,
decoder_input_details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: None,
},
@ -644,13 +649,22 @@ async fn chat_completions(
)
};
let (headers, response_stream) =
generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
let (headers, response_stream) = generate_stream_internal(
infer,
compute_type,
Json(generate_request),
on_message_callback,
)
.await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
let (headers, Json(generation)) =
generate(Extension(infer), Json(generate_request)).await?;
let (headers, Json(generation)) = generate(
Extension(infer),
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -672,6 +686,52 @@ async fn chat_completions(
}
}
/// Tokenize inputs
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/tokenize",
request_body = GenerateRequest,
responses(
(status = 200, description = "Tokenized ids", body = TokenizeResponse),
(status = 404, description = "No tokenizer found", body = ErrorResponse,
example = json ! ({"error": "No fast tokenizer available"})),
)
)]
#[instrument(skip_all)]
async fn tokenize(
Extension(infer): Extension<Infer>,
Json(req): Json<GenerateRequest>,
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
let input = req.inputs.clone();
let encoding = infer.tokenize(req).await?;
if let Some(encoding) = encoding {
let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
.zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| {
let text: String = input.chars().skip(start).take(stop - start).collect();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect();
Ok(Json(TokenizeResponse(tokens)))
} else {
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
error_type: "no fast tokenizer".to_string(),
}),
))
}
}
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
@ -683,6 +743,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
}
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
@ -708,7 +771,7 @@ pub async fn run(
ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
chat_enabled_api: bool,
messages_api_enabled: bool,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
@ -719,6 +782,8 @@ pub async fn run(
compat_generate,
generate,
generate_stream,
chat_completions,
tokenize,
metrics,
),
components(
@ -726,10 +791,18 @@ pub async fn run(
Info,
CompatGenerateRequest,
GenerateRequest,
ChatRequest,
Message,
ChatCompletionChoice,
ChatCompletionDelta,
ChatCompletionChunk,
ChatCompletion,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
TokenizeResponse,
SimpleToken,
BestOfSequence,
Details,
FinishReason,
@ -863,21 +936,26 @@ pub async fn run(
// Define base and health routes
let base_routes = Router::new()
.route("/", post(compat_generate))
.route("/", get(health))
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions))
.route("/tokenize", post(tokenize))
.route("/health", get(health))
.route("/ping", get(health))
.route("/metrics", get(metrics));
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if chat_enabled_api {
let aws_sagemaker_route = if messages_api_enabled {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
} else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
let compute_type =
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
// Combine routes and layers
let app = Router::new()
.merge(swagger_ui)
@ -887,6 +965,7 @@ pub async fn run(
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(compute_type))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(cors_layer);

View File

@ -70,12 +70,11 @@ impl Validation {
}
#[instrument(skip(self, inputs))]
async fn validate_input(
pub async fn tokenize(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize, u32), ValidationError> {
) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
@ -88,7 +87,24 @@ impl Validation {
// Await on response channel
// Unwrap is safe here
let (inputs, input_length) = response_receiver.await.unwrap()?;
let encoding = response_receiver.await.unwrap()?;
Ok(Some(encoding))
} else {
Ok(None)
}
}
#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
let input_length = encoding.len();
// Get total tokens
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
@ -343,36 +359,31 @@ fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<
/// Get input length and optionally truncate it
fn prepare_input(
inputs: String,
mut inputs: String,
truncate: Option<usize>,
tokenizer: &Tokenizer,
) -> Result<(String, usize), ValidationError> {
) -> Result<(tokenizers::Encoding, String), ValidationError> {
// Get the number of tokens in the input
let mut encoding = tokenizer
.encode(inputs.clone(), true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
// Optionally truncate
let (inputs, input_length) = match truncate {
// Truncate is some and < encoding length
Some(truncate) if truncate < encoding.len() => {
// truncate encoding and decode new inputs
if let Some(truncate) = truncate {
if truncate < encoding.len() {
encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer
inputs = tokenizer
.decode(encoding.get_ids(), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
}
// Nothing to do
_ => (inputs, encoding.len()),
};
}
Ok((inputs, input_length))
Ok((encoding, inputs))
}
type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(String, usize), ValidationError>>,
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
Span,
);

View File

@ -1,4 +1,4 @@
eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
eetq_commit := 71adb5e191bb8290069a580abff0355d7b2dd5c9
eetq:
# Clone eetq
@ -6,7 +6,7 @@ eetq:
git clone https://github.com/NetEase-FuXi/EETQ.git eetq
build-eetq: eetq
cd eetq && git fetch && git checkout $(eetq_commit)
cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
cd eetq && python setup.py build
install-eetq: build-eetq

View File

@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
//
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ < 700
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

View File

@ -2,8 +2,11 @@
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#include "../cu_compat.cuh"
#include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
@ -128,7 +131,7 @@ __global__ void q4_matmul_kernel
if constexpr (use_half2)
{
half result = __hadd(acc.x, acc.y);
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else

View File

@ -1,12 +1,23 @@
#ifndef _compat_gemm_cuh
#define _compat_gemm_cuh
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#if defined(USE_ROCM)
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
// for symbols as hipblasHalf.
#include <hipblas/hipblas.h>
// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
@ -31,8 +42,10 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t
#define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif
#endif

View File

@ -8,7 +8,11 @@
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess

View File

@ -1,5 +1,15 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_cuda_cflags = ["-lineinfo", "-O3"]
if torch.version.hip:
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
extra_compile_args = {
"nvcc": extra_cuda_cflags,
}
setup(
name="exllamav2_kernels",
@ -11,6 +21,7 @@ setup(
"exllamav2_kernels/cuda/q_matrix.cu",
"exllamav2_kernels/cuda/q_gemm.cu",
],
extra_compile_args=extra_compile_args,
)
],
cmdclass={"build_ext": BuildExtension},

1193
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "1.3.4"
version = "1.4.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.15.0"
huggingface-hub = "^0.19.3"
transformers = "^4.36.1"
transformers = "^4.37.1"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }

View File

@ -13,11 +13,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -28,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -12,11 +12,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -27,18 +27,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,77 @@
import torch
from text_generation_server.utils.layers import (
TensorParallelEmbedding,
)
class ProcessGroup:
def __init__(self, rank: int, world_size: int):
self._rank = rank
self.world_size = world_size
def size(self) -> int:
return self.world_size
def rank(self) -> int:
return self._rank
class Weights:
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
self.weight = (
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
)
self.process_group = ProcessGroup(rank, world_size)
def get_partial_sharded(self, name: str, dim: int):
assert dim == 0
rank = self.process_group.rank()
world_size = self.process_group.size()
size = self.weight.shape[dim]
block_size = (size + world_size - 1) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
return self.weight[start:stop]
def get_shape(self, name: str):
return self.weight.shape
def test_weight_hub_files_offline_error():
vocab_size = 17
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
output = embeddings.forward(input_ids)
assert embeddings.min_id == 0
assert embeddings.max_id == 17
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
assert embeddings_0_2.min_id == 0
assert embeddings_0_2.max_id == 9
torch.testing.assert_close(
embeddings_0_2.weight,
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
.view(10, 256)
.float(),
)
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
assert embeddings_1_2.min_id == 9
assert embeddings_1_2.max_id == 17
torch.testing.assert_close(
embeddings_1_2.weight,
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
.view(9, 256)
.float(),
)
output_tp_0 = embeddings_0_2.forward(input_ids)
output_tp_1 = embeddings_1_2.forward(input_ids)
torch.testing.assert_close(output, output_tp_0 + output_tp_1)

View File

@ -50,19 +50,39 @@ def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
accepted_ids = torch.ones_like(top_n_tokens_tensor)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == []
assert topn_tok_ids[1] == [0, 3]
assert topn_tok_ids[2] == [0, 3, 1, 4]
assert topn_tok_ids[3] == [0, 3, 1, 4]
assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == []
assert topn_tok_logprobs[1] == [-1, -2]
assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
# Now let's make second member of the batch be speculated
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
accepted_ids[1] = 2
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3], [0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]

View File

@ -19,6 +19,7 @@ from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.mamba import Mamba
from text_generation_server.models.phi import Phi
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@ -58,6 +59,7 @@ try:
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
@ -73,6 +75,7 @@ if FLASH_ATTENTION:
__all__.append(IDEFICSSharded)
__all__.append(FlashMistral)
__all__.append(FlashMixtral)
__all__.append(FlashPhi)
def get_model(
@ -247,6 +250,39 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == "phi":
if FLASH_ATTENTION:
return FlashPhi(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
)
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "phi-msft":
if FLASH_ATTENTION:
raise NotImplementedError(
"Legacy phi-msft is not supported with Flash Attention"
)
else:
return Phi(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "llama" or model_type == "baichuan":
if FLASH_ATTENTION:
return FlashLlama(

View File

@ -580,10 +580,13 @@ class CausalLM(Model):
generations: List[Generation] = []
stopped = True
# Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
accepted_ids,
)
start_decode = time.time_ns()
@ -692,20 +695,24 @@ class CausalLM(Model):
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None

View File

@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_dim = int(config.rotary_pct * self.head_size)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@ -98,8 +100,11 @@ class FlashNeoxAttention(torch.nn.Module):
)
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.rotary_dim,
base=config.rotary_emb_base,
device=weights.device,
)
self.softmax_scale = self.head_size ** (-0.5)

View File

@ -0,0 +1,410 @@
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastLayerNorm,
)
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
hidden_size=2560,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="gelu_fast", # llama uses silu
layer_norm_eps=1e-05, # rms in llama,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
resid_pdrop=0.1, # llama doesn't have this
partial_rotary_factor=0.5, # important difference between llama and phi
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.rope_theta = rope_theta
self.resid_pdrop = resid_pdrop
self.partial_rotary_factor = partial_rotary_factor
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
# this is the same as llama except for Phi uses bias=True
return TensorParallelColumnLinear(
get_linear(weight, bias=True, quantize=config.quantize)
)
class FlashPhiAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.softmax_scale = self.head_size**-0.5
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.rotary_dim,
base=config.rope_theta,
device=weights.device,
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.dense",
weights=weights,
bias=True,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
# Compute query, key, value and split
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
# Reshape query and key for rotary embeddings
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# NOTE: this is the main difference between Llama and Phi
# in llama the rotary embeddings are applied to the whole query and key.
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
#
# Apply partial positional embeddings in place
self.rotary_emb(
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
)
# Reshape key and value and cache
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
)
# llama weights are up_proj and down_proj and bias=False
self.up_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=True,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=True,
)
def forward(self, hidden_states):
# NOTE: Llama requires the gate up states to an intermediate size
# Phi does not and we can avoid the `view` operation
return self.down_proj(self.act(self.up_proj(hidden_states)))
class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = self.resid_dropout(attn_output).add(
self.resid_dropout(self.mlp(hidden_states))
)
return hidden_states, res
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashPhiLayer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
self.norm = FastLayerNorm.load(
prefix="model.final_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = FlashPhiModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
return self.lm_head(hidden_states)

View File

@ -28,7 +28,6 @@ EPS = 1e-5
def load_col(config, prefix, weights, bias):
assert bias == False, NotImplementedError
assert config.quantize != "gptq", NotImplementedError
slice_ = weights._get_slice(f"{prefix}.weight")
rank = weights.process_group.rank()
@ -45,7 +44,36 @@ def load_col(config, prefix, weights, bias):
if weight.dtype != torch.int32:
weight = weight.to(dtype=weights.dtype)
weight = weight.to(device=weights.device)
bias = None
if bias:
bias_slice_ = weights._get_slice(f"{prefix}.bias")
bias_rank = weights.process_group.rank()
bias_size = weights.process_group.size()
bias_h = bias_slice_.get_shape()
bias_h = bias_h[0]
bias_block_size = bias_h // bias_size
bias_q_part = bias_slice_[
bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
]
bias_k_part = bias_slice_[
bias_h
+ bias_rank * bias_block_size : bias_h
+ (bias_rank + 1) * bias_block_size
]
bias_v_part = bias_slice_[
2 * bias_h
+ bias_rank * bias_block_size : 2 * bias_h
+ (bias_rank + 1) * bias_block_size
]
bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
if bias.dtype != torch.int32:
bias = bias.to(dtype=weights.dtype)
bias = bias.to(device=weights.device)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return TensorParallelColumnLinear(linear)
@ -330,7 +358,16 @@ class MultiheadAttention(nn.Module):
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
if self.qk_ln:
raise NotImplementedError("qk_ln is not supported")
bias = not config.no_bias
hidden_size = config.d_model
head_dim = hidden_size // self.n_heads
self.q_ln = LPLayerNorm(
d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
)
self.k_ln = LPLayerNorm(
self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
)
if self.attn_impl == "flash":
self.attn_fn = flash_attn_fn
elif self.attn_impl == "triton":
@ -581,12 +618,20 @@ class MPTBlock(nn.Module):
f"""Not implemented attn {config.attn_config["attn_type"]}"""
)
resid_pdrop = config.resid_pdrop
self.norm_1 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
)
self.norm_2 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
if config.no_bias:
self.norm_1 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
)
self.norm_2 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
else:
self.norm_1 = nn.LayerNorm.load(
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
)
self.norm_2 = nn.LayerNorm.load(
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@ -635,6 +680,9 @@ class LPLayerNorm(torch.nn.LayerNorm):
elementwise_affine=True,
device=None,
dtype=None,
bias: Optional[bool] = True,
prefix=None,
weights=None,
):
super().__init__(
normalized_shape=normalized_shape,
@ -642,7 +690,13 @@ class LPLayerNorm(torch.nn.LayerNorm):
elementwise_affine=elementwise_affine,
device=device,
dtype=dtype,
bias=bias,
)
if weights is not None:
self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
if bias:
self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
self.normalized_shape = self.weight.shape
def forward(self, x):
module_device = x.device
@ -755,20 +809,23 @@ class MPTModel(MPTPreTrainedModel):
)
self.wte = TensorParallelEmbedding("transformer.wte", weights)
if not self.alibi:
# self.wpe = torch.nn.Embedding(
# config.max_seq_len, config.d_model, device=config.init_device
# )
raise RuntimeError("no alibi no supported")
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
self.blocks = nn.ModuleList(
[
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
for i in range(config.n_layers)
]
)
self.norm_f = nn.LayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=EPS
)
if config.no_bias:
self.norm_f = nn.LayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=EPS
)
else:
self.norm_f = nn.LayerNorm.load(
prefix="transformer.norm_f", weights=weights, eps=EPS
)
self.is_causal = not self.prefix_lm
self._attn_bias_initialized = False
self.attn_bias = None
@ -787,8 +844,9 @@ class MPTModel(MPTPreTrainedModel):
if config.verbose:
warnings.warn(f"Removing bias ({module.bias}) from {module}.")
module.register_parameter("bias", None)
if config.verbose and config.verbose > 2:
print(self)
if hasattr(self.config, "verbose"):
if config.verbose and config.verbose > 2:
print(self)
if "verbose" not in self.config.init_config:
self.config.init_config["verbose"] = self.config.verbose
if self.config.init_config["verbose"] > 1:

View File

@ -0,0 +1,330 @@
# imlementation of the PhiModel and PhiForCausalLM classes
import torch
import torch.distributed
import math
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLinear,
)
# PhiConfig is the configuration class for the PhiModel.
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
n_positions=2048,
n_embd=2560,
n_layer=32,
n_inner=None,
n_head=32,
rotary_dim=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_vocab_size_multiple=64,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
no_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = rotary_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_word_embeddings = tie_word_embeddings
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.no_bias = no_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
inv_freq_len = len(inv_freq)
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
freqs = t.matmul(inv_freq)
self.sin = freqs.sin()
self.cos = freqs.cos()
def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
b_size, seqlen, three, _, _headdim = qkv.shape
if three != 3:
raise Exception("unexpected shape for qkv")
_, rotary_dim = self.cos.shape
rotary_dim = rotary_dim * 2
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
q12 = torch.chunk(q_rot, 2, dim=-1)
k12 = torch.chunk(k_rot, 2, dim=-1)
q1, q2 = q12[0], q12[1]
k1, k2 = k12[0], k12[1]
c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
q_rot = torch.cat(
[
q1 * c - q2 * s,
q1 * s + q2 * c,
],
dim=-1,
)
k_rot = torch.cat(
[
k1 * c - k2 * s,
k1 * s + k2 * c,
],
dim=-1,
)
q = torch.cat([q_rot, q_pass], dim=-1)
k = torch.cat([k_rot, k_pass], dim=-1)
v = qkv[:, :, 2]
return q, k, v
# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
class PhiCausalLMHead(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.ln = nn.LayerNorm.load(
prefix="lm_head.ln",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.linear = TensorParallelHead.load(
config=config, prefix="lm_head.linear", weights=weights
)
def forward(self, hidden_states):
hidden_states = self.ln(hidden_states)
hidden_states = self.linear(hidden_states)
return hidden_states
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=not config.no_bias,
)
self.op_size = config.n_embd
self.head_dim = int(config.n_embd / config.n_head)
self.num_heads = config.n_head
self.rotary_emb = RotaryEmbedding(
config.rotary_dim,
config.n_positions,
)
self.softmax_scale = 1.0 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states,
past_kv_cache,
attention_mask=None,
):
b_size, seq_len, _n_embd = hidden_states.shape
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)
# if there is a kv_cache, then we need to concatenate
if past_kv_cache is not None:
prev_k, prev_v = past_kv_cache
k = torch.cat([prev_k, k], dim=1)
v = torch.cat([prev_v, v], dim=1)
past_kv_cache = [k, v]
attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
if attention_mask is not None:
seqlen_k = k.shape[1]
seqlen_q = q.shape[1]
causal_mask = torch.triu(
torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
1,
)
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
attn_output = (
attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
.transpose(1, 2)
.flatten(-2)
)
return self.out_proj(attn_output), past_kv_cache
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.n_inner = config.n_inner
self.fc1 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=False,
)
self.fc2 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=False,
)
self.activation = torch.nn.functional.gelu
def forward(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
class PhiBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.layer_norm = nn.LayerNorm.load(
prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
)
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
def forward(
self,
hidden_states,
kv_cache,
attention_mask,
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, past_kv_cache = self.mixer(
hidden_states, kv_cache, attention_mask
)
feed_forward_hidden_states = self.mlp(hidden_states)
out = attn_outputs + feed_forward_hidden_states + residual
return out, past_kv_cache
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[
PhiBlock(f"transformer.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
hidden_states = self.embed_tokens(input_ids)
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
past_key_values = (
[None] * len(self.blocks) if past_key_values is None else past_key_values
)
for index, block in enumerate(self.blocks):
hidden_states, new_key_values = block(
hidden_states, past_key_values[index], mask
)
past_key_values[index] = new_key_values
return hidden_states, past_key_values
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = PhiModel(config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
model_output = self.model(
input_ids, past_key_values, attention_mask, return_dict, use_cache
)
logits = self.lm_head(model_output[0])
loss = None
if labels is not None:
loss = nn.CrossEntropyLoss()(
logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
)
if not return_dict:
return (
((loss,) + (logits,) + model_output[1:])
if loss is not None
else (logits,) + model_output[1:]
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=model_output[1],
hidden_states=None,
attentions=None,
)

View File

@ -842,6 +842,8 @@ class FlashCausalLM(Model):
else:
next_token_logits = out
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
@ -851,16 +853,15 @@ class FlashCausalLM(Model):
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen],
next_token_logits,
get_speculate(),
speculate,
batch.speculative_ids,
speculative_logits,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -1062,20 +1063,24 @@ class FlashCausalLM(Model):
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None

View File

@ -74,9 +74,9 @@ class FlashLlama(FlashCausalLM):
import os
from pathlib import Path
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(

View File

@ -0,0 +1,102 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
PhiConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashPhi(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashPhi is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -0,0 +1,66 @@
import torch
import torch.distributed
from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class Phi(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, _rank, _world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
model = PhiForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)

View File

@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
batch.past_key_values,
)
# Speculation is not active for seq2seq
accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
accepted_ids,
)
start_decode = time.time_ns()
@ -746,20 +749,24 @@ class Seq2SeqLM(Model):
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None

View File

@ -95,5 +95,5 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
)

View File

@ -1,12 +1,9 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from logging import getLogger
import torch
import torch.nn as nn
import math
logger = getLogger(__name__)
from loguru import logger
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
@ -185,6 +182,10 @@ class QuantLinear(nn.Module):
"g_idx": self.g_idx,
}
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
def forward(self, x, force_cuda=False):

View File

@ -33,14 +33,14 @@ except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
V2 = False
log_once(
logger.warning,
"Disabling exllama v2 and using v1 instead because there are issues when sharding",
)
# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
# V2 = False
# log_once(
# logger.warning,
# "Disabling exllama v2 and using v1 instead because there are issues when sharding",
# )
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
@ -507,10 +507,12 @@ class TensorParallelEmbedding(nn.Module):
world_size = process_group.size()
rank = process_group.rank()
block_size = num_embeddings // world_size
block_size = (num_embeddings + world_size - 1) // world_size
self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size)
self.null_idx = block_size
self.null_idx = weight.shape[
0
] # Usually block_size, might be less in non even vocab_size.
self.process_group = weights.process_group
self.reduce = reduce

View File

@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
scores[:, j] = _scores
next_ids[:, j] = _next_ids
next_ids = next_ids.view(B * S)
scores = scores.view(B * S, -1)
allscores = scores.view(B * S, -1)
alllogprobs = torch.log_softmax(allscores, -1)
if speculated_ids is not None:
accepted_ids = []
@ -305,16 +306,17 @@ class HeterogeneousNextTokenChooser:
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
)
next_ids = next_ids[indices]
scores = scores[indices]
logprobs = alllogprobs[indices]
indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1]
else:
accepted_ids = torch.ones_like(next_ids)
logprobs = alllogprobs
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0:
if speculative_scores is not None:
# Medusa provided some scores
@ -327,7 +329,7 @@ class HeterogeneousNextTokenChooser:
else:
speculative_ids = None
return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
def filter(self, indices):
if self.watermark_processor is not None:
@ -436,8 +438,8 @@ class HeterogeneousSampling:
def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
@ -446,14 +448,19 @@ def batch_top_tokens(
max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used
if max_top_n == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
@ -471,13 +478,33 @@ def batch_top_tokens(
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()
return (
[
idxs[:n] if req_n > 0 else []
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
],
[
vals[:n] if req_n > 0 else []
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
],
)
batch_top_token_ids = []
batch_top_token_logprobs = []
accepted_ids_list = accepted_ids.tolist()
for i, n_accepted_ids in enumerate(accepted_ids_list):
start = speculate_size * i
stop = speculate_size * (i + 1)
_top_indices = top_indices[start: stop]
_top_values = top_values[start: stop]
_top_n_ishes = top_n_ishes[start: stop]
_top_n_tokens = top_n_tokens[start: stop]
_top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids]
_top_n_ishes = _top_n_ishes[:n_accepted_ids]
_top_n_tokens = _top_n_tokens[:n_accepted_ids]
row_top_token_ids = []
row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else []
row_top_token_ids.append(indices)
row_top_token_logprobs.append(values)
batch_top_token_ids.append(row_top_token_ids)
batch_top_token_logprobs.append(row_top_token_logprobs)
return batch_top_token_ids, batch_top_token_logprobs

View File

@ -92,7 +92,7 @@ class Weights:
rank = self.process_group.rank()
size = slice_.get_shape()[dim]
block_size = size // world_size
block_size = (size + world_size - 1) // world_size
start = rank * block_size
stop = (rank + 1) * block_size