Merge branch 'gptq-cuda-kernels' of https://github.com/fxmarty/text-generation-inference into gptq-cuda-kernels

This commit is contained in:
Felix Marty 2023-07-05 16:42:37 +00:00
commit 620ed7d8aa
62 changed files with 4623 additions and 2152 deletions

Cargo.lock (generated): 505 changes

File diff suppressed because it is too large

View File

@ -8,7 +8,7 @@ members = [
]
[workspace.package]
-version = "0.8.2"
+version = "0.9.0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -1,5 +1,5 @@
# Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.70 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya
# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder
@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-# Copy build artifacts from transformers builder
+# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

View File

@ -43,8 +43,8 @@ to power LLMs api-inference widgets.
- Tensor Parallelism for faster inference on multiple GPUs
- Token streaming using Server-Sent Events (SSE)
- [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
-- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures
+- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@ -84,7 +84,7 @@ model=bigscience/bloom-560m
num_shard=2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.8 --model-id $model --num-shard $num_shard
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9 --model-id $model --num-shard $num_shard
```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.

View File

@ -1,15 +0,0 @@
# Azure ML endpoint
## Create all resources
```shell
az ml model create -f model.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
```
## Update deployment
```shell
az ml online-deployment update -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
```

View File

@ -1,38 +0,0 @@
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
name: bloom-deployment
endpoint_name: bloom-inference
model: azureml:bloom-safetensors:1
model_mount_path: /var/azureml-model
environment_variables:
WEIGHTS_CACHE_OVERRIDE: /var/azureml-model/bloom-safetensors
MODEL_ID: bigscience/bloom
NUM_SHARD: 8
environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2.0
inference_config:
liveness_route:
port: 80
path: /health
readiness_route:
port: 80
path: /health
scoring_route:
port: 80
path: /generate
instance_type: Standard_ND96amsr_A100_v4
request_settings:
request_timeout_ms: 90000
max_concurrent_requests_per_instance: 256
liveness_probe:
initial_delay: 600
timeout: 90
period: 120
success_threshold: 1
failure_threshold: 5
readiness_probe:
initial_delay: 600
timeout: 90
period: 120
success_threshold: 1
failure_threshold: 5
instance_count: 1

View File

@ -1,3 +0,0 @@
$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
name: bloom-inference
auth_mode: key

View File

@ -1,3 +0,0 @@
$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
name: bloom-safetensors
path: /data/bloom-safetensors

View File

@ -14,36 +14,85 @@ use tracing_subscriber::EnvFilter;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
/// The name of the tokenizer (as in model_id on the huggingface hub, or local path).
#[clap(short, long, env)]
tokenizer_name: String,
/// The revision to use for the tokenizer if on the hub.
#[clap(default_value = "main", long, env)]
revision: String,
/// The various batch sizes to benchmark for; the idea is to get enough
/// batching to start seeing increased latency. This usually means you're
/// moving from memory bound (usual at BS=1) to compute bound, and this is
/// a sweet spot for the maximum batch size for the model under test.
#[clap(short, long)]
batch_size: Option<Vec<u32>>,
/// This is the length in tokens of the initial prompt sent to the
/// text-generation-server. A longer prompt will slow down the benchmark. Usually the
/// latency grows somewhat linearly with this for the prefill step.
///
/// Most importantly, the prefill step is usually not the one dominating
/// your runtime, so it's ok to keep it short.
#[clap(default_value = "10", short, long, env)]
sequence_length: u32,
/// This is how many tokens will be generated by the server and averaged out
/// to give the `decode` latency. This is the *critical* number you want to optimize for,
/// as LLMs spend most of their time doing decoding.
///
/// Decode latency is usually quite stable.
#[clap(default_value = "8", short, long, env)]
decode_length: u32,
/// How many runs should we average over.
#[clap(default_value = "10", short, long, env)]
runs: usize,
/// Number of warmup cycles
#[clap(default_value = "1", short, long, env)]
warmups: usize,
-#[clap(long, env)]
-temperature: Option<f32>,
-#[clap(long, env)]
-top_k: Option<u32>,
-#[clap(long, env)]
-top_p: Option<f32>,
-#[clap(long, env)]
-typical_p: Option<f32>,
-#[clap(long, env)]
-repetition_penalty: Option<f32>,
-#[clap(long, env)]
-watermark: bool,
-#[clap(long, env)]
-do_sample: bool,
+/// The location of the gRPC socket. This benchmark tool bypasses the router
+/// completely and talks directly to the gRPC processes.
#[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
master_shard_uds_path: String,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
temperature: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_k: Option<u32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_p: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
typical_p: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
repetition_penalty: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
watermark: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
do_sample: bool,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "0.8.2" "version": "0.9.0"
}, },
"paths": { "paths": {
"/": { "/": {
@ -270,6 +270,35 @@
} }
} }
}, },
"/health": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Health check method",
"description": "Health check method",
"operationId": "health",
"responses": {
"200": {
"description": "Everything is working fine"
},
"503": {
"description": "Text generation inference is down",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "unhealthy",
"error_type": "healthcheck"
}
}
}
}
}
}
},
"/info": { "/info": {
"get": { "get": {
"tags": [ "tags": [

View File

@ -0,0 +1,140 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5117188,
"text": " is"
},
{
"id": 18147,
"logprob": -8.96875,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.953125,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.94189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5830078,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3105469,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.3215332,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5566406,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.6074219,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.6923828,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5263672,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.8544922,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6118164,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.055877686,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0537109,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.0115737915,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9111328,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4589844,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.4853516,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021636963,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
}

View File

@ -0,0 +1,562 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5117188,
"text": " is"
},
{
"id": 18147,
"logprob": -8.96875,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.953125,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.94189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5830078,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3183594,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.32617188,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.6015625,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.67822266,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5395508,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.8623047,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6020508,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.0552063,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0742188,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011405945,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9165039,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4501953,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.4960938,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.02116394,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5,
"text": " is"
},
{
"id": 18147,
"logprob": -8.984375,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.96875,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.93359375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5800781,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3242188,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.31835938,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5644531,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.5957031,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.68603516,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5258789,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.859375,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6166992,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.056762695,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0703125,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011428833,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9213867,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4726562,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.5039062,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021652222,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5,
"text": " is"
},
{
"id": 18147,
"logprob": -8.984375,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.96875,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.93359375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5800781,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3242188,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.31835938,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5644531,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.5957031,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.68603516,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5258789,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.859375,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6166992,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.056762695,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0703125,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011428833,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9213867,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4726562,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.5039062,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021652222,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5,
"text": " is"
},
{
"id": 18147,
"logprob": -8.984375,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.96875,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.93359375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5800781,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3242188,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.31835938,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5644531,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.5957031,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.68603516,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5258789,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.859375,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6166992,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.056762695,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0703125,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011428833,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9213867,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4726562,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.5039062,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021652222,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
}
]

View File

@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
response = await flash_neox.generate(
@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -0,0 +1,48 @@
import pytest
@pytest.fixture(scope="module")
def mpt_sharded_handle(launcher):
with launcher("mosaicml/mpt-7b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mpt_sharded(mpt_sharded_handle):
await mpt_sharded_handle.health(300)
return mpt_sharded_handle.client
@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
response = await mpt_sharded.generate(
"What is Deep Learning?",
max_new_tokens=17,
decoder_input_details=True,
)
assert response.details.generated_tokens == 17
assert (
response.generated_text
== " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
responses = await generate_load(
mpt_sharded,
"What is Deep Learning?",
max_new_tokens=17,
n=4,
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert (
responses[0].generated_text
== " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
)
assert responses == response_snapshot

View File

@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
-use std::sync::Arc;
-use std::sync::{mpsc, Mutex};
+use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
@ -40,6 +39,25 @@ impl std::fmt::Display for Quantization {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
Float16,
BFloat16,
}
impl std::fmt::Display for Dtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To keep in sync with `server`.
match self {
Dtype::Float16 => {
write!(f, "float16")
}
Dtype::BFloat16 => {
write!(f, "bfloat16")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
@ -76,6 +94,10 @@ struct Args {
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
@ -106,7 +128,7 @@ struct Args {
/// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load.
/// Please note that some models have a finite range of sequence they can handle.
-#[clap(default_value = "1000", long, env)]
+#[clap(default_value = "1024", long, env)]
max_input_length: usize,
/// This is the most important value to set as it defines the "memory budget"
@ -117,15 +139,9 @@ struct Args {
/// `1511` max_new_tokens.
/// The larger this value, the larger amount each request will be in your RAM
/// and the less effective batching can be.
-#[clap(default_value = "1512", long, env)]
+#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
-/// The maximum allowed batch size during dynamic batching.
-/// Using `max_batch_total_tokens` should be favored in general
-/// as it's a finer way to control RAM usage.
-#[clap(long, env)]
-max_batch_size: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting
/// ones into the same batch.
@ -139,6 +155,12 @@ struct Args {
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32, waiting_served_ratio: f32,
/// Limits the number of tokens for the prefill operation.
/// Since this operation take the most memory and is compute bound, it is interesting
/// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
/// **IMPORTANT** This is one critical control to allow maximum usage /// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware. /// of the available hardware.
/// ///
@ -151,19 +173,12 @@ struct Args {
/// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
/// or a single query of `1000` tokens. /// or a single query of `1000` tokens.
/// ///
/// So you don't have to control that finely
/// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
/// want maximum flexibility. However, for your users if they are asking for the full amount of
/// total tokens, they are likely to wait for a very long time to get a spot
/// in the batch (since they are going to be alone) so setting `max_batch_size`
/// and `max_total_tokens` can still be useful to prevent those long waiting times.
///
/// Overall this number should be the largest possible amount that fits the /// Overall this number should be the largest possible amount that fits the
/// remaining memory (after the model is loaded). Since the actual memory overhead /// remaining memory (after the model is loaded). Since the actual memory overhead
/// depends on other parameters like if you're using quantization, flash attention /// depends on other parameters like if you're using quantization, flash attention
/// or the model implementation, text-generation-inference cannot infer this number /// or the model implementation, text-generation-inference cannot infer this number
/// automatically. /// automatically.
#[clap(default_value = "32000", long, env)] #[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
/// This setting defines how many tokens can be passed before forcing the waiting /// This setting defines how many tokens can be passed before forcing the waiting
@ -185,9 +200,9 @@ struct Args {
/// for end users. /// for end users.
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
max_waiting_tokens: usize, max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
/// The port to listen on. /// The port to listen on.
#[clap(default_value = "3000", long, short, env)]
port: u16, port: u16,
/// The name of the socket for gRPC communication between the webserver /// The name of the socket for gRPC communication between the webserver
@ -262,7 +277,7 @@ struct Args {
#[derive(Debug)]
enum ShardStatus {
Ready,
-Failed((usize, String)),
+Failed((usize, Option<String>)),
}
#[allow(clippy::too_many_arguments)]
@ -270,6 +285,7 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@ -283,7 +299,7 @@ fn shard_manager(
watermark_delta: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
-shutdown: Arc<Mutex<bool>>,
+shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
@ -319,6 +335,11 @@ fn shard_manager(
shard_argv.push(quantize.to_string())
}
if let Some(dtype) = dtype {
shard_argv.push("--dtype".to_string());
shard_argv.push(dtype.to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_argv.push("--revision".to_string());
@ -334,6 +355,12 @@ fn shard_manager(
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
env.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -409,20 +436,20 @@ fn shard_manager(
}
}
status_sender
-.send(ShardStatus::Failed((rank, err.to_string())))
+.send(ShardStatus::Failed((rank, Some(err.to_string()))))
.unwrap();
return;
}
};
// Redirect STDOUT to the console
-let shard_stdout = p.stdout.take().unwrap();
+let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
thread::spawn(move || {
// Enter shard-manager tracing span
-let stdout = BufReader::new(shard_stdout);
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
-for line in stdout.lines() {
+for line in shard_stdout_reader.lines() {
// Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
@ -436,8 +463,22 @@ fn shard_manager(
loop {
// Process exited
if let Some(exit_status) = p.poll() {
-let mut err = String::new();
-p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
+// We read stderr in another thread as it seems that `read_to_string` can block
+// indefinitely in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
let mut err = String::new();
shard_stderr_reader.read_to_string(&mut err).unwrap();
err_sender.send(err).unwrap_or(());
});
let err = err_receiver
.recv_timeout(Duration::from_millis(100))
.map_err(|err| {
tracing::error!("Unable to read shard {rank} error from stderr");
err
})
.ok();
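Note on the change above: stderr is now drained on a helper thread guarded by `recv_timeout`, so a pipe that never closes cannot hang the launcher. A minimal, standalone sketch of that pattern (the helper name and values here are illustrative, not from this PR):

```rust
use std::io::Read;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

// Read from a possibly-blocking source on a helper thread and give up after a
// short timeout instead of hanging the caller.
fn read_with_timeout<R: Read + Send + 'static>(mut reader: R, timeout: Duration) -> Option<String> {
    let (sender, receiver) = mpsc::channel();
    thread::spawn(move || {
        let mut buf = String::new();
        // Ignore read errors and a dropped receiver; the caller may already have moved on.
        if reader.read_to_string(&mut buf).is_ok() {
            let _ = sender.send(buf);
        }
    });
    receiver.recv_timeout(timeout).ok()
}

fn main() {
    // An empty reader returns immediately with an empty string.
    let out = read_with_timeout(std::io::empty(), Duration::from_millis(100));
    assert_eq!(out.as_deref(), Some(""));
}
```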
if let ExitStatus::Signaled(signal) = exit_status {
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
@ -450,8 +491,8 @@ fn shard_manager(
}
// We received a shutdown signal
-if *shutdown.lock().unwrap() {
+if shutdown.load(Ordering::SeqCst) {
-p.terminate().unwrap();
+p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated");
return;
@ -470,14 +511,11 @@ fn shard_manager(
}
}
-fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
+fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
tracing::info!("Shutting down shards");
// Update shutdown value to true
// This will be picked up by the shard manager
-{
-let mut shutdown = shutdown.lock().unwrap();
-*shutdown = true;
-}
+shutdown.store(true, Ordering::SeqCst);
// Wait for shards to shutdown
// This will block till all shutdown_sender are dropped
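The shutdown flag also changes from `Arc<Mutex<bool>>` to `Arc<AtomicBool>`: a single atomic store replaces the lock/unlock dance, and there is no mutex left to poison if a thread panics. A minimal sketch of the pattern (illustrative only, not code from this PR):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

fn main() {
    let shutdown = Arc::new(AtomicBool::new(false));

    // A worker polls the flag, much like shard_manager does after this change.
    let worker_flag = shutdown.clone();
    let worker = thread::spawn(move || {
        while !worker_flag.load(Ordering::SeqCst) {
            thread::sleep(Duration::from_millis(10));
        }
        println!("worker observed shutdown");
    });

    // The equivalent of shutdown_shards(): one atomic store, no lock needed.
    shutdown.store(true, Ordering::SeqCst);
    worker.join().unwrap();
}
```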
@ -719,7 +757,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
fn spawn_shards(
num_shard: usize,
args: &Args,
-shutdown: Arc<Mutex<bool>>,
+shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>,
status_receiver: &mpsc::Receiver<ShardStatus>,
@ -749,6 +787,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@ -759,6 +798,7 @@ fn spawn_shards(
model_id,
revision,
quantize,
dtype,
trust_remote_code,
uds_path,
rank,
@ -793,7 +833,10 @@ fn spawn_shards(
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
-tracing::error!("Shard {} failed to start:\n{}", rank, err);
+tracing::error!("Shard {rank} failed to start");
if let Some(err) = err {
tracing::error!("{err}");
}
shutdown_shards(shutdown, shutdown_receiver);
return Err(LauncherError::ShardCannotStart);
}
@ -809,7 +852,7 @@ fn spawn_shards(
fn spawn_webserver(
args: Args,
-shutdown: Arc<Mutex<bool>>,
+shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Popen, LauncherError> {
// All shard started
@ -827,6 +870,10 @@ fn spawn_webserver(
args.max_input_length.to_string(),
"--max-total-tokens".to_string(),
args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
@ -839,15 +886,6 @@ fn spawn_webserver(
args.model_id,
];
-// Deprecate max_batch_size
-if let Some(max_batch_size) = args.max_batch_size {
-argv.push("--max-batch-size".to_string());
-argv.push(max_batch_size.to_string())
-} else {
-argv.push("--max-batch-total-tokens".to_string());
-argv.push(args.max_batch_total_tokens.to_string())
-}
// Model optional revision
if let Some(ref revision) = args.revision {
argv.push("--revision".to_string());
@ -981,7 +1019,7 @@ fn main() -> Result<(), LauncherError> {
download_convert_model(&args, running.clone())?;
// Shared shutdown bool
-let shutdown = Arc::new(Mutex::new(false));
+let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
@ -1006,14 +1044,21 @@ fn main() -> Result<(), LauncherError> {
return Ok(());
}
-let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
+let mut webserver =
spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;
// Default exit code
let mut exit_code = Ok(());
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-tracing::error!("Shard {rank} failed:\n{err}");
+tracing::error!("Shard {rank} crashed");
if let Some(err) = err {
tracing::error!("{err}");
}
exit_code = Err(LauncherError::ShardFailed);
break;
};

View File

@ -11,6 +11,8 @@ service TextGenerationService {
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
@ -192,3 +194,13 @@ message DecodeResponse {
/// Next batch (cached)
optional CachedBatch batch = 2;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
}
/// Empty response
message WarmupResponse {}

View File

@ -22,11 +22,11 @@ text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] }
flume = "0.10.14"
futures = "0.3.26"
-metrics = "0.20.1"
+metrics = "0.21.0"
-metrics-exporter-prometheus = { version = "0.11.0", features = [] }
+metrics-exporter-prometheus = { version = "0.12.1", features = [] }
nohash-hasher = "0.2.0"
-opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
+opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
-opentelemetry-otlp = "0.11.0"
+opentelemetry-otlp = "0.12.0"
rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
@ -36,7 +36,7 @@ tokenizers = "0.13.3"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = "0.1.37"
-tracing-opentelemetry = "0.18.0"
+tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

View File

@ -11,10 +11,10 @@ grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.11" prost = "^0.11"
thiserror = "^1.0" thiserror = "^1.0"
tokio = { version = "^1.25", features = ["sync"] } tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.8" tonic = "^0.9"
tower = "^0.4" tower = "^0.4"
tracing = "^0.1" tracing = "^0.1"
[build-dependencies] [build-dependencies]
tonic-build = "0.8.4" tonic-build = "0.9.2"
prost-build = "0.11.6" prost-build = "0.11.6"

View File

@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@ -94,6 +95,63 @@ impl Client {
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize),
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 2,
stop_sequences: vec![],
ignore_eos_token: false,
}),
prefill_logprobs: true,
});
n_tokens += max_input_length;
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_total_tokens,
})
.inject_context();
self.stub.warmup(request).await?.into_inner();
Ok(())
}
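For intuition, the warmup loop above keeps appending synthetic requests of `max_input_length` tokens until `max_prefill_tokens` is covered. A small sketch of the resulting request count, assuming the default values this PR sets in the router (1024 input length, 4096 prefill tokens); the helper name is made up for illustration:

```rust
// Counts how many synthetic warmup requests the loop above would create,
// given a per-request prompt length and the overall prefill token budget.
fn warmup_request_count(max_input_length: u32, max_prefill_tokens: u32) -> u32 {
    let mut n_tokens = 0;
    let mut count = 0;
    while n_tokens < max_prefill_tokens {
        // Each synthetic request is truncated server-side to `max_input_length` tokens
        n_tokens += max_input_length;
        count += 1;
    }
    count
}

fn main() {
    // 4096 / 1024 => 4 requests are enough to fill the prefill budget
    assert_eq!(warmup_request_count(1024, 4096), 4);
}
```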
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -87,6 +87,27 @@ impl ShardedClient {
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
opentelemetry = "0.18.0" opentelemetry = "^0.19"
tonic = "^0.8" tonic = "^0.9"
tracing = "^0.1" tracing = "^0.1"
tracing-opentelemetry = "0.18.0" tracing-opentelemetry = "^0.19"

View File

@ -45,6 +45,7 @@ impl Infer {
client: ShardedClient,
validation: Validation,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
@ -61,6 +62,7 @@ impl Infer {
tokio::spawn(batching_task(
client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
queue.clone(),
@ -240,9 +242,11 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
queue: Queue,
@ -257,8 +261,9 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
-while let Some((mut entries, batch, span)) =
-queue.next_batch(None, max_batch_total_tokens).await
+while let Some((mut entries, batch, span)) = queue
+.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
@ -284,11 +289,12 @@ async fn batching_task(
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
-let token_budget = max_batch_total_tokens - batch_max_tokens;
+let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
// Try to get a new batch
-if let Some((mut new_entries, new_batch, span)) =
-queue.next_batch(min_size, token_budget).await
+if let Some((mut new_entries, new_batch, span)) = queue
+.next_batch(min_size, max_batch_prefill_tokens, token_budget)
+.await
{
// Tracking metrics
if min_size.is_some() {
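A side note on the `saturating_sub` above: with plain `-` on `u32`, a cached batch reporting more `max_tokens` than `max_batch_total_tokens` would panic in debug builds (or wrap around in release); saturating simply clamps the remaining budget to zero. A tiny illustration with hypothetical numbers:

```rust
fn main() {
    let max_batch_total_tokens: u32 = 16_000;
    let batch_max_tokens: u32 = 17_000; // hypothetical oversized running batch
    // Clamp instead of underflowing: no panic, just no budget left for new entries.
    let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
    assert_eq!(token_budget, 0);
}
```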

View File

@ -28,15 +28,15 @@ struct Args {
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
-#[clap(default_value = "1000", long, env)]
+#[clap(default_value = "1024", long, env)]
max_input_length: usize,
-#[clap(default_value = "1512", long, env)]
+#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
-#[clap(long, env)]
-max_batch_size: Option<usize>,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
-#[clap(default_value = "32000", long, env)]
+#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> {
max_stop_sequences,
max_input_length,
max_total_tokens,
-max_batch_size,
waiting_served_ratio,
-mut max_batch_total_tokens,
+max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
port,
master_shard_uds_path,
@ -97,8 +97,18 @@ fn main() -> Result<(), std::io::Error> {
ngrok_password,
} = args;
// Validate args
if max_input_length as u32 > max_batch_prefill_tokens {
panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"));
}
if max_total_tokens as u32 > max_batch_total_tokens {
panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"));
}
if validation_workers == 0 {
-panic!("validation_workers must be > 0");
+panic!("`validation_workers` must be > 0");
}
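As a sanity check, the new defaults introduced elsewhere in this PR satisfy these invariants (1024 <= 4096 <= 16000, and 2048 <= 16000). A quick illustrative snippet, not part of the diff:

```rust
fn main() {
    // Default CLI values set by this PR
    let max_input_length: u32 = 1024;
    let max_total_tokens: u32 = 2048;
    let max_batch_prefill_tokens: u32 = 4096;
    let max_batch_total_tokens: u32 = 16_000;

    // Mirror the router's startup validation above
    assert!(max_input_length <= max_batch_prefill_tokens);
    assert!(max_batch_prefill_tokens <= max_batch_total_tokens);
    assert!(max_total_tokens <= max_batch_total_tokens);
    println!("default values pass the router's startup validation");
}
```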
// CORS allowed origins
@ -141,12 +151,6 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
init_logging(otlp_endpoint, json_output);
-if let Some(max_batch_size) = max_batch_size {
-tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
-max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
-tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
-}
if tokenizer.is_none() {
tracing::warn!(
"Could not find a fast tokenizer implementation for {tokenizer_name}"
@ -161,10 +165,16 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
-false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
-tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
-}),
+false => get_model_info(&tokenizer_name, &revision, authorization_token)
+.await
+.unwrap_or_else(|| {
+tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
}
}),
};
// if pipeline-tag == text-generation we default to return_full_text = true
@ -190,6 +200,17 @@ fn main() -> Result<(), std::io::Error> {
.info()
.await
.expect("Unable to get shard info");
// Warmup model
tracing::info!("Warming up model");
sharded_client
.warmup(
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
.expect("Unable to warmup model");
tracing::info!("Connected"); tracing::info!("Connected");
// Binds on localhost // Binds on localhost
@ -206,6 +227,7 @@ fn main() -> Result<(), std::io::Error> {
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
sharded_client, sharded_client,
@ -219,7 +241,7 @@ fn main() -> Result<(), std::io::Error> {
ngrok_username, ngrok_username,
ngrok_password, ngrok_password,
) )
.await; .await;
Ok(()) Ok(())
}) })
} }

View File

@ -58,6 +58,7 @@ impl Queue {
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
// Create response channel
@ -67,6 +68,7 @@ impl Queue {
self.queue_sender
.send(QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span: Span::current(),
@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
}
QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span,
} => span.in_scope(|| {
-let next_batch = state.next_batch(min_size, token_budget);
+let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}),
@ -140,7 +143,12 @@ impl State {
}
// Get the next batch
-fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
+fn next_batch(
&mut self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if self.entries.is_empty() {
return None;
}
@ -184,7 +192,9 @@ impl State {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
-if (prefill_tokens + decode_tokens) > token_budget {
+if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
{
// Entry is over budget
// Add it back to the front
self.entries.push_front((id, entry));
@ -259,6 +269,7 @@ enum QueueCommand {
Append(Box<Entry>, Span), Append(Box<Entry>, Span),
NextBatch { NextBatch {
min_size: Option<usize>, min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32, token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>, response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span, span: Span,
@ -328,8 +339,8 @@ mod tests {
fn test_next_batch_empty() { fn test_next_batch_empty() {
let mut state = State::new(false); let mut state = State::new(false);
assert!(state.next_batch(None, 1).is_none()); assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none());
} }
#[test] #[test]
@ -340,7 +351,7 @@ mod tests {
state.append(entry1); state.append(entry1);
state.append(entry2); state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 2).unwrap(); let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -356,7 +367,7 @@ mod tests {
let (entry3, _guard3) = default_entry(); let (entry3, _guard3) = default_entry();
state.append(entry3); state.append(entry3);
assert!(state.next_batch(Some(2), 2).is_none()); assert!(state.next_batch(Some(2), 2, 2).is_none());
assert_eq!(state.next_id, 3); assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1); assert_eq!(state.entries.len(), 1);
@ -372,7 +383,7 @@ mod tests {
state.append(entry1); state.append(entry1);
state.append(entry2); state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 1).unwrap(); let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
@ -385,7 +396,7 @@ mod tests {
let (entry3, _guard3) = default_entry(); let (entry3, _guard3) = default_entry();
state.append(entry3); state.append(entry3);
let (entries, batch, _) = state.next_batch(None, 3).unwrap(); let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
@ -408,8 +419,8 @@ mod tests {
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false); let queue = Queue::new(false);
assert!(queue.next_batch(None, 1).await.is_none()); assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
} }
#[tokio::test] #[tokio::test]
@ -420,7 +431,7 @@ mod tests {
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
@ -433,11 +444,11 @@ mod tests {
queue.append(entry3); queue.append(entry3);
// Not enough requests pending // Not enough requests pending
assert!(queue.next_batch(Some(2), 2).await.is_none()); assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
// Not enough token budget // Not enough token budget
assert!(queue.next_batch(Some(1), 0).await.is_none()); assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
// Ok // Ok
let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap(); let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
assert_eq!(entries2.len(), 1); assert_eq!(entries2.len(), 1);
assert!(entries2.contains_key(&2)); assert!(entries2.contains_key(&2));
assert!(entries2.get(&2).unwrap().batch_time.is_some()); assert!(entries2.get(&2).unwrap().batch_time.is_some());
@ -453,7 +464,7 @@ mod tests {
queue.append(entry1); queue.append(entry1);
queue.append(entry2); queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
@ -462,7 +473,7 @@ mod tests {
let (entry3, _guard3) = default_entry(); let (entry3, _guard3) = default_entry();
queue.append(entry3); queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
@ -476,6 +487,6 @@ mod tests {
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);
assert!(queue.next_batch(None, 1).await.is_none()); assert!(queue.next_batch(None, 1, 1).await.is_none());
} }
} }
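Note: State::next_batch now enforces two budgets instead of one: prefill_token_budget caps the tokens that must be prefilled for the new batch, while token_budget still caps prefill plus decode tokens together. A simplified Python rendering of the selection loop (ignoring the padding-aware accounting and span bookkeeping of the real Rust code; entries are reduced to (id, input_length, max_new_tokens) tuples):

    from collections import deque
    from typing import Optional

    def next_batch(entries: deque, min_size: Optional[int],
                   prefill_token_budget: int, token_budget: int):
        batch, prefill_tokens, decode_tokens = [], 0, 0
        while entries:
            entry_id, input_length, max_new_tokens = entries.popleft()
            prefill_tokens += input_length
            decode_tokens += max_new_tokens
            if prefill_tokens > prefill_token_budget or (prefill_tokens + decode_tokens) > token_budget:
                # Entry is over budget: put it back at the front and stop filling the batch.
                entries.appendleft((entry_id, input_length, max_new_tokens))
                break
            batch.append(entry_id)
        if min_size is not None and len(batch) < min_size:
            return None  # the real implementation also restores the popped entries here
        return batch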

View File

@ -514,6 +514,7 @@ pub async fn run(
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
client: ShardedClient, client: ShardedClient,
@ -531,6 +532,7 @@ pub async fn run(
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
paths( paths(
health,
get_model_info, get_model_info,
compat_generate, compat_generate,
generate, generate,
@ -582,6 +584,7 @@ pub async fn run(
client, client,
validation, validation,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_concurrent_requests, max_concurrent_requests,

View File

@ -1,3 +1,3 @@
[toolchain] [toolchain]
channel = "1.69.0" channel = "1.70.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]

View File

@ -1,4 +1,5 @@
include Makefile-flash-att include Makefile-flash-att
include Makefile-vllm
unit-tests: unit-tests:
pytest -s -vv -m "not private" tests pytest -s -vv -m "not private" tests

View File

@ -1,9 +1,9 @@
flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
flash-attention: flash-attention:
# Clone flash attention # Clone flash attention
pip install packaging pip install packaging
git clone https://github.com/OlivierDehaene/flash-attention.git git clone https://github.com/HazyResearch/flash-attention.git
build-flash-attention: flash-attention build-flash-attention: flash-attention
cd flash-attention && git fetch && git checkout $(flash_att_commit) cd flash-attention && git fetch && git checkout $(flash_att_commit)

13
server/Makefile-vllm Normal file
View File

@ -0,0 +1,13 @@
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
vllm:
# Clone vllm
git clone https://github.com/OlivierDehaene/vllm.git
build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)
cd vllm && python setup.py build
install-vllm: build-vllm
pip uninstall vllm -y || true
cd vllm && python setup.py install
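Note: the new Makefile-vllm targets build the vLLM CUDA kernels from the OlivierDehaene/vllm fork at the pinned commit; once installed they are imported by the flash models below as vllm_cache_ops (reshape_and_cache) and vllm_attention_ops (single_query_cached_kv_attention). A quick check that the extensions were built for the active Python/CUDA environment (hypothetical snippet, not part of the repository):

    # Sanity check after `make install-vllm`: both extensions must import cleanly,
    # otherwise the flash_* models will fail at import time.
    import importlib

    for name in ("vllm_cache_ops", "vllm_attention_ops"):
        importlib.import_module(name)
        print(f"{name}: ok")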

1904
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-server" name = "text-generation-server"
version = "0.8.2" version = "0.9.0"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -26,7 +26,8 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97" sentencepiece = "^0.1.97"
tokenizers = "0.13.3" tokenizers = "0.13.3"
huggingface-hub = "^0.14.1" huggingface-hub = "^0.14.1"
transformers = "^4.29.2" transformers = "4.29.2"
einops = "^0.6.1"
[tool.poetry.extras] [tool.poetry.extras]
accelerate = ["accelerate"] accelerate = ["accelerate"]

View File

@ -1,21 +1,23 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0" backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0" certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0" charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0" click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0" filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0" fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0" googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0" grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0" grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0" grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0" grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0" hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0" huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0" idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0" loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0" numpy==1.25.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@ -26,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0" packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0" protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0"
pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0" pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0" regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0" requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0" safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0" setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0" transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0" typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0" typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0" urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0" wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"

View File

@ -22,7 +22,9 @@ class Cache:
del batch del batch
def clear(self): def clear(self):
self.cache.clear() keys = list(self.cache.keys())
for k in keys:
self.delete(k)
def __len__(self): def __len__(self):
return len(self.cache.keys()) return len(self.cache.keys())
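Note: Cache.clear() no longer drops the dict wholesale; it routes every key through delete() so whatever per-batch cleanup delete() performs also runs when the whole cache is cleared. A condensed sketch of the pattern, with the batch type reduced to a single free() method (hypothetical simplification):

    from typing import Dict

    class Cache:
        def __init__(self):
            self.cache: Dict[int, "Batch"] = {}

        def delete(self, batch_id: int):
            batch = self.cache.pop(batch_id, None)
            if batch is not None:
                batch.free()  # per-batch cleanup, e.g. releasing KV-cache blocks

        def clear(self):
            # Iterate over a copy of the keys: delete() mutates self.cache.
            for batch_id in list(self.cache.keys()):
                self.delete(batch_id)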

View File

@ -17,12 +17,18 @@ class Quantization(str, Enum):
gptq_cuda = "gptq-cuda" gptq_cuda = "gptq-cuda"
class Dtype(str, Enum):
float16 = "float16"
bloat16 = "bfloat16"
@app.command() @app.command()
def serve( def serve(
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
sharded: bool = False, sharded: bool = False,
quantize: Optional[Quantization] = None, quantize: Optional[Quantization] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server", uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO", logger_level: str = "INFO",
@ -65,7 +71,14 @@ def serve(
# Downgrade enum into str for easier management later on # Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value quantize = None if quantize is None else quantize.value
server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path) dtype = None if dtype is None else dtype.value
if dtype is not None and quantize is not None:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
)
@app.command() @app.command()
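Note: serve() now accepts a --dtype option next to --quantize, downgrades both enums to plain strings, and refuses to run when both are given, since each one decides how the final weights are loaded. A minimal sketch of that contract as a standalone helper (the function name is made up for illustration):

    from typing import Optional, Tuple

    def resolve_load_options(dtype: Optional[str], quantize: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
        # Mirrors the check added to `serve`: dtype ("float16" / "bfloat16") and
        # quantize are mutually exclusive ways of choosing the final weights.
        if dtype is not None and quantize is not None:
            raise RuntimeError("Only one of `dtype` and `quantize` can be set")
        return dtype, quantize

    resolve_load_options("bfloat16", None)   # ok
    resolve_load_options("float16", "gptq")  # raises RuntimeError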

View File

@ -10,6 +10,7 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded from text_generation_server.models.opt import OPTSharded
@ -100,11 +101,25 @@ def get_model(
revision: Optional[str], revision: Optional[str],
sharded: bool, sharded: bool,
quantize: Optional[str], quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
) -> Model: ) -> Model:
if dtype is None:
dtype = torch.float16
elif dtype == "float16":
dtype = torch.float16
elif dtype == "bfloat16":
dtype = torch.bfloat16
else:
raise RuntimeError(f"Unknown dtype {dtype}")
if "facebook/galactica" in model_id: if "facebook/galactica" in model_id:
return GalacticaSharded( return GalacticaSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
) )
if model_id.startswith("bigcode/"): if model_id.startswith("bigcode/"):
@ -113,6 +128,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded: elif sharded:
@ -124,6 +140,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -138,6 +155,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded: elif sharded:
@ -149,11 +167,20 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "bloom": if model_type == "bloom":
return BLOOMSharded( return BLOOMSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "mpt":
return MPTSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
) )
@ -163,6 +190,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded: elif sharded:
@ -170,6 +198,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
@ -177,6 +206,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -186,6 +216,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded: elif sharded:
@ -195,6 +226,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -210,6 +242,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
raise NotImplementedError( raise NotImplementedError(
@ -221,6 +254,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
@ -228,12 +262,17 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == "opt": elif model_type == "opt":
return OPTSharded( return OPTSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
) )
elif model_type == "t5": elif model_type == "t5":
@ -241,6 +280,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -253,11 +293,19 @@ def get_model(
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
) )
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM( return Seq2SeqLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
) )
auto_map = config_dict.get("auto_map", None) auto_map = config_dict.get("auto_map", None)
@ -267,6 +315,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if "AutoModelForSeq2SeqLM" in auto_map.keys(): if "AutoModelForSeq2SeqLM" in auto_map.keys():
@ -274,6 +323,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
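Note: get_model() now threads an optional dtype string down to every model class: None falls back to the historical default, "float16" and "bfloat16" map to the corresponding torch dtypes, and anything else is rejected. A compact sketch of the combined effect (hypothetical helper; the real code splits this between get_model and the model constructors, which force float32 on CPU):

    import torch
    from typing import Optional

    def resolve_torch_dtype(dtype: Optional[str], cuda_available: bool) -> torch.dtype:
        # None keeps the previous behaviour: float16 on GPU, float32 on CPU.
        if dtype is None:
            return torch.float16 if cuda_available else torch.float32
        if dtype == "float16":
            return torch.float16
        if dtype == "bfloat16":
            return torch.bfloat16
        raise RuntimeError(f"Unknown dtype {dtype}")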

View File

@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
@ -454,11 +454,12 @@ class CausalLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
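Note: the max_tokens estimate in CausalLMBatch.from_pb changes from len(inputs) * max_input_length + max_decode_tokens to len(inputs) * (max_input_length + max_decode_tokens), so the decode allowance is now inside the per-batch multiplication. With len(inputs) = 4, max_input_length = 100 and max_decode_tokens = 20, the old expression reserved 420 token slots versus 480 with the new one:

    def max_tokens(num_inputs: int, max_input_length: int, max_decode_tokens: int) -> int:
        # Formula used by the updated CausalLMBatch.from_pb
        return num_inputs * (max_input_length + max_decode_tokens)

    assert max_tokens(4, 100, 20) == 480   # previous formula: 4 * 100 + 20 == 420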

View File

@ -23,12 +23,16 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
import dropout_layer_norm import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -106,7 +110,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix=f"{prefix}.rotary_emb", weights=weights prefix=f"{prefix}.rotary_emb", weights=weights
) )
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size**-0.5
self.num_heads = self.num_heads // weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi( self.query_key_value = TensorParallelColumnLinear.load_multi(
@ -122,20 +126,21 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights, weights=weights,
bias=False, bias=False,
) )
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -144,23 +149,23 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill vllm_cache_ops.reshape_and_cache(
if prefill: qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
# Copy to layer past )
layer_past[...] = qkv[:, 1:]
# output # output tensor
attn_output = torch.empty_like(qkv[:, 0]) attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if cu_seqlen_prefill is not None:
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output, attn_output,
start_seq, cu_seqlen_prefill,
end_seq, cu_seqlen_prefill,
start_seq,
end_seq,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -173,31 +178,19 @@ class FlashLlamaAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv[:, 0] # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
# Add present to the layer_past tensor at the correct indices block_size = kv_cache[1].shape[3]
layer_past[past_present_indices] = qkv[:, 1:] vllm_attention_ops.single_query_cached_kv_attention(
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
attn_output, attn_output,
start_seq_q, qkv[:, 0],
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -265,14 +258,12 @@ class FlashLlamaLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -281,14 +272,12 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states, normed_hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
# faster post attention rms norm # faster post attention rms norm
@ -333,40 +322,17 @@ class FlashLlamaModel(torch.nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values=None, ) -> torch.Tensor:
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@ -380,34 +346,17 @@ class FlashLlamaModel(torch.nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache[i],
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
past_key_values[:, i],
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
@ -423,31 +372,27 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return logits, present return logits
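Note: the Llama attention layer no longer materialises a contiguous past tensor per layer. Keys and values are written into a pre-allocated paged cache (kv_cache[0] / kv_cache[1]) at the positions given by slots, and at decode time the vLLM kernel gathers them back through block_tables. A stripped-down, pure-PyTorch sketch of the bookkeeping those two kernels implement (layout simplified to [total_slots, num_heads, head_size]; the real vLLM block layouts differ):

    import torch

    def reshape_and_cache_ref(key, value, key_cache, value_cache, slots):
        # key / value: [num_tokens, num_heads, head_size]; slots: flat cache index per token.
        key_cache[slots] = key
        value_cache[slots] = value

    def gather_past(key_cache, value_cache, block_table, input_length, block_size):
        # block_table: block ids owned by one sequence. Convert (block id, offset)
        # pairs into flat slot indices for the first input_length tokens.
        positions = torch.arange(input_length)
        slots = block_table[positions // block_size] * block_size + positions % block_size
        return key_cache[slots], value_cache[slots]

Because every sequence only owns whole blocks, freeing a finished request returns its blocks to a shared pool instead of leaving holes in one large contiguous past tensor.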

View File

@ -25,11 +25,15 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -110,20 +114,21 @@ class FlashNeoxAttention(torch.nn.Module):
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True config, prefix=f"{prefix}.dense", weights=weights, bias=True
) )
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -132,23 +137,23 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill vllm_cache_ops.reshape_and_cache(
if prefill: qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
# Copy to layer past )
layer_past[...] = qkv[:, 1:]
# output # output tensor
attn_output = torch.empty_like(qkv[:, 0]) attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if cu_seqlen_prefill is not None:
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output, attn_output,
start_seq, cu_seqlen_prefill,
end_seq, cu_seqlen_prefill,
start_seq,
end_seq,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -161,31 +166,19 @@ class FlashNeoxAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv[:, 0] # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
# Add present to the layer_past tensor at the correct indices block_size = kv_cache[1].shape[3]
layer_past[past_present_indices] = qkv[:, 1:] vllm_attention_ops.single_query_cached_kv_attention(
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
attn_output, attn_output,
start_seq_q, qkv[:, 0],
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -250,14 +243,12 @@ class FlashNeoXLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
if self.use_parallel_residual: if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states) ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -266,14 +257,12 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states, ln1_hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -292,14 +281,12 @@ class FlashNeoXLayer(nn.Module):
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@ -346,40 +333,17 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values=None, ) -> torch.Tensor:
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
@ -393,34 +357,17 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual, residual,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache[i],
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
past_key_values[:, i],
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual) hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
@ -434,31 +381,27 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.gpt_neox( hidden_states = self.gpt_neox(
input_ids, input_ids,
position_ids, position_ids,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states) logits = self.embed_out(hidden_states)
return logits, present return logits
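Note: each attention module now registers a kv_head_mapping telling the paged-attention kernel which KV head serves each query head. The three variants that appear across these files reduce to a few one-liners (head counts below are arbitrary examples):

    import torch

    num_heads = 8  # query heads

    # Multi-head attention (Llama, NeoX): every query head has its own KV head.
    mha_mapping = torch.arange(0, num_heads, dtype=torch.int32)   # [0, 1, ..., 7]

    # Multi-query attention (Santacoder, RW with num_heads_kv == 1): all query heads share KV head 0.
    mqa_mapping = torch.zeros(num_heads, dtype=torch.int32)       # [0, 0, ..., 0]

    # Grouped attention (RefinedWeb large): query heads are split across num_groups KV heads.
    num_groups, heads_per_group = 2, 4
    gqa_mapping = torch.arange(0, num_groups, dtype=torch.int32).repeat_interleave(heads_per_group)
    # -> tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)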

View File

@ -4,11 +4,15 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -126,19 +130,26 @@ class FlashRWAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
if self.num_heads_kv == 1:
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
else:
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -156,25 +167,27 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill vllm_cache_ops.reshape_and_cache(
if prefill: kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
# Copy to layer past )
layer_past[...] = kv
# Expand to query shape # output
kv = kv.expand(-1, 2, self.num_heads, self.head_size) attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output, attn_output,
start_seq, cu_seqlen_prefill,
end_seq, cu_seqlen_prefill,
start_seq,
end_seq,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -187,32 +200,19 @@ class FlashRWAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
# Add present to the layer_past tensor at the correct indices # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
layer_past[past_present_indices] = kv block_size = kv_cache[1].shape[3]
# Expand to query shape vllm_attention_ops.single_query_cached_kv_attention(
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output, attn_output,
start_seq_q, query,
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -264,19 +264,21 @@ class FlashRWLargeAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
self.kv_head_mapping = torch.arange(
0, self.num_groups, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_heads)
def forward( def forward(
self, self,
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -293,10 +295,19 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
)
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if prefill: if cu_seqlen_prefill is not None:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape # Expand to query shape
kv = ( kv = (
kv.unsqueeze(2) kv.unsqueeze(2)
@ -304,18 +315,14 @@ class FlashRWLargeAttention(torch.nn.Module):
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
) )
# output
attn_output = torch.empty_like(query)
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
query, query,
torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1), torch.select(kv, dim=2, index=1),
attn_output, attn_output,
start_seq, cu_seqlen_prefill,
end_seq, cu_seqlen_prefill,
start_seq,
end_seq,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -328,36 +335,19 @@ class FlashRWLargeAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
# Add present to the layer_past tensor at the correct indices # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
layer_past[past_present_indices] = kv block_size = kv_cache[1].shape[3]
# Expand to query shape vllm_attention_ops.single_query_cached_kv_attention(
kv = (
layer_past.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output, attn_output,
start_seq_q, query,
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.dense( return self.dense(
@ -432,14 +422,12 @@ class FlashRWLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
if self.parallel_attn: if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -448,14 +436,12 @@ class FlashRWLayer(nn.Module):
ln_hidden_states, ln_hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
mlp_output = self.mlp(ln_hidden_states) mlp_output = self.mlp(ln_hidden_states)
@ -472,14 +458,12 @@ class FlashRWLayer(nn.Module):
hidden_states, hidden_states,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@ -523,14 +507,12 @@ class FlashRWLargeLayer(nn.Module):
residual, residual,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
ln_attn, residual = self.ln_attn(hidden_states, residual) ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual) ln_mlp, _ = self.ln_mlp(residual)
@ -540,14 +522,12 @@ class FlashRWLargeLayer(nn.Module):
ln_attn, ln_attn,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
# MLP. # MLP.
@ -580,11 +560,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = self.h[0].self_attention.num_heads_kv
2,
self.h[0].self_attention.num_heads_kv,
self.h[0].self_attention.head_size,
)
elif config.model_type == "RefinedWeb": elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
@ -592,11 +568,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = self.h[0].self_attention.num_groups
self.h[0].self_attention.num_groups,
2,
self.h[0].self_attention.head_size,
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"model_type {config.model_type} is not supported." f"model_type {config.model_type} is not supported."
@ -612,38 +584,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values=None, ) -> torch.Tensor:
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
*self.cache_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
@ -657,32 +608,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual, residual,
cos, cos,
sin, sin,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache[i],
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashRWForCausalLM(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel):
@ -697,31 +633,27 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
position_ids, position_ids,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return logits, present return logits
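Note: during prefill these RW/Santacoder-style modules still go through flash_attn_cuda.fwd, which expects one KV head per query head, so the compact KV tensor is broadcast up to the query shape (kv.expand / key_value.expand) while only the compact version is written to the paged cache. A tiny sketch of that expansion for the multi-query case, with arbitrary shapes:

    import torch

    num_tokens, num_heads, head_size = 5, 8, 64

    # One stored KV head (multi-query attention): [num_tokens, 2, 1, head_size]
    kv = torch.randn(num_tokens, 2, 1, head_size)

    # expand() broadcasts the single KV head across all query heads without copying memory.
    kv_for_flash = kv.expand(-1, 2, num_heads, head_size)
    assert kv_for_flash.shape == (num_tokens, 2, num_heads, head_size)

The grouped variant in FlashRWLargeAttention does the same thing with an extra unsqueeze/reshape so that each group's KV pair is shared by that group's query heads.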

View File

@ -3,10 +3,15 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -96,6 +101,7 @@ def _load_multi_mqa_gptq(
def _load_multi_mqa( def _load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
): ):
if any("c_attn" in k for k in weights.routing.keys()): if any("c_attn" in k for k in weights.routing.keys()):
slice_ = weights._get_slice(f"{prefix}.c_attn.weight") slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape() shape = slice_.get_shape()
@ -239,18 +245,19 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row( self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
) )
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
qkv = self.c_attn(hidden_states) qkv = self.c_attn(hidden_states)
@ -263,10 +270,15 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size)
vllm_cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if prefill: if cu_seqlen_prefill is not None:
# Copy to layer past
layer_past[...] = key_value
# Expand from 1 to num_heads # Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
@ -278,10 +290,8 @@ class FlashMQAttention(torch.nn.Module):
torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1), torch.select(key_value, dim=1, index=1),
attn_output, attn_output,
start_seq, cu_seqlen_prefill,
end_seq, cu_seqlen_prefill,
start_seq,
end_seq,
max_s, max_s,
max_s, max_s,
0.0, 0.0,
@ -294,32 +304,19 @@ class FlashMQAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
# Add present to the layer_past tensor at the correct indices # kv_cache[1] => [num_blocks, 1, head_size, block_size]
layer_past[past_present_indices] = key_value block_size = kv_cache[1].shape[3]
# Expand from 1 to num_heads vllm_attention_ops.single_query_cached_kv_attention(
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output, attn_output,
start_seq_q, query,
end_seq_q, kv_cache[0],
start_seq, kv_cache[1],
end_seq, self.kv_head_mapping,
1,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False, block_tables,
False, input_lengths,
False, block_size,
0, max_s,
None,
) )
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
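# Sketch of the decode path above, added for clarity and not part of this change: the
# kernel gathers K/V through block_tables and input_lengths instead of a contiguous
# layer_past. With block_size tokens per block, position p of a sequence sits in cache slot
#   block_tables[seq][p // block_size] * block_size + p % block_size
# which matches the slots that reshape_and_cache wrote the new key/value pairs to.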
@ -379,27 +376,23 @@ class Block(nn.Module):
self, self,
hidden_states, hidden_states,
residual, residual,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn( hidden_states = self.attn(
hidden_states, hidden_states,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
layer_past,
past_present_indices,
prefill,
) )
hidden_states, residual = self.ln_2(hidden_states, residual) hidden_states, residual = self.ln_2(hidden_states, residual)
@ -445,64 +438,36 @@ class FlashSantacoderModel(nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values=None, ) -> torch.Tensor:
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1: if self.process_group.size() > 1:
torch.distributed.all_reduce(hidden_states, group=self.process_group) torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
# Decode
else:
prefill = False
residual = None residual = None
for i, layer in enumerate(self.h): for i, layer in enumerate(self.h):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
residual, residual,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache[i],
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
) )
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values return hidden_states
class FlashSantacoderForCausalLM(nn.Module): class FlashSantacoderForCausalLM(nn.Module):
@ -539,31 +504,27 @@ class FlashSantacoderForCausalLM(nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
position_ids, position_ids: torch.Tensor,
start_seq, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
start_seq_q, block_tables: torch.Tensor,
end_seq_q, slots: torch.Tensor,
max_s, input_lengths: torch.Tensor,
past_present_indices, max_s: int,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
): ) -> torch.Tensor:
hidden_states, present = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
position_ids, position_ids,
start_seq, cu_seqlen_prefill,
end_seq, kv_cache,
start_seq_q, block_tables,
end_seq_q, slots,
input_lengths,
max_s, max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return logits, present return logits

File diff suppressed because it is too large

View File

@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
try: try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError: except RuntimeError:
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights) self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
encoder_config = copy.deepcopy(config) encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False encoder_config.is_decoder = False

View File

@ -1,11 +1,14 @@
import math
import itertools
import torch import torch
import torch.distributed import torch.distributed
import numpy as np import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model from text_generation_server.models import Model
@ -20,6 +23,92 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
class CacheManager:
def __init__(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.block_size = BLOCK_SIZE
element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, self.block_size, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, self.block_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int32
).view(num_blocks, self.block_size)
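# Illustrative sketch of the shapes allocated above; the concrete sizes are assumptions,
# not taken from this diff. With fp16 the element size is 2 bytes, so for BLOCK_SIZE == 16
# the packing factor is x == 8 and each layer holds:
#   key cache   -> (num_blocks, num_heads, head_size // 8, 16, 8)
#   value cache -> (num_blocks, num_heads, head_size, 16)
# e.g. num_blocks=640, num_heads=16, head_size=128 gives (640, 16, 16, 16, 8) and
# (640, 16, 128, 16), while self.slots enumerates 640 * 16 == 10240 cache slots.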
def allocate(self, batch: "FlashCausalLMBatch"):
# Get free block indices by finding values in the mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero()
assert (
len(free_block_indices) >= batch.blocks
), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
# Slice by the number of required blocks
block_indices = free_block_indices[: batch.blocks]
block_indices = block_indices.flatten()
# Padded block tables
block_tables_tensor = torch.zeros(
(len(batch), batch.max_blocks), dtype=torch.int32
)
# Allocate paged attention blocks
cumulative_blocks = 0
slots = []
block_tables = []
for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
# Get allocated blocks for this sequence
allocated_blocks = block_indices[
cumulative_blocks : cumulative_blocks + needed_blocks
]
# Get slots for the allocated blocks
allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
slots.append(allocated_slots)
block_tables.append(allocated_blocks.tolist())
block_tables_tensor[i, :needed_blocks] = allocated_blocks
cumulative_blocks += needed_blocks
batch.needed_blocks_slots = None
batch.block_tables = block_tables
batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
batch.slots = torch.concat(slots).to(batch.input_ids.device)
# Allocate the required number of blocks by setting the mask to 0
self.free_block_mask[block_indices] = 0
def free(self, block_indices: Optional[List[int]]):
if block_indices is not None and block_indices:
# Reset mask
self.free_block_mask[block_indices] = 1
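# A worked example of allocate()/free(); the numbers are assumed for illustration only.
# With needed_blocks_slots == [(2, 29), (1, 10)] and every block free, block_indices ==
# [0, 1, 2]:
#   sequence 0 gets blocks [0, 1] -> slots 0..28  (29 of its 32 slots)
#   sequence 1 gets block  [2]    -> slots 32..41 (10 of its 16 slots)
# block_tables_tensor becomes the padded [[0, 1], [2, 0]], free_block_mask[[0, 1, 2]] is
# zeroed, and a later free([0, 1, 2]) flips those entries back to 1.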
@dataclass @dataclass
class FlashCausalLMBatch(Batch): class FlashCausalLMBatch(Batch):
@ -32,23 +121,29 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor input_ids: torch.Tensor
position_ids: torch.Tensor position_ids: torch.Tensor
# Indices to copy present to the correct indices in the pre-allocated past key values # Flash Attention values
past_present_indices: torch.Tensor
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Paged Attention values
# Set when creating the batch
# CPU tensor of length b indicating the start of each sequence in slots
start_slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
# List of tuples of ints representing the number of blocks and slots needed by each sequence
needed_blocks_slots: Optional[List[Tuple[int, int]]]
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
block_tables: Optional[List[List[int]]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: Optional[torch.Tensor]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: Optional[torch.Tensor]
# tensor of length b holding starting offset of each sequence
start_seq: torch.Tensor
# tensor of length b holding ending offset of each sequence
end_seq: torch.Tensor
# tensor of length b holding starting offset of each sequence, only used in prefill
start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding starting offset of each query sequence, only used in decode
start_seq_q: Optional[torch.Tensor]
# tensor of length b holding ending offset of each query sequence, only used in decode
end_seq_q: Optional[torch.Tensor]
# past key values, only used in decode
past_key_values: Optional[torch.Tensor]
max_seqlen: int max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs # Prefill metadata tensors to efficiently compute logprobs
@ -62,6 +157,7 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
input_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]] prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]] read_offsets: List[Optional[int]]
@ -69,15 +165,17 @@ class FlashCausalLMBatch(Batch):
next_token_chooser: HeterogeneousNextTokenChooser next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to # Number of blocks in this batch
max_tokens: int blocks: int
# Maximum number of blocks
max_blocks: int
def to_pb(self) -> generate_pb2.CachedBatch: def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch( return generate_pb2.CachedBatch(
id=self.batch_id, id=self.batch_id,
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.blocks * BLOCK_SIZE,
) )
@classmethod @classmethod
@ -99,12 +197,10 @@ class FlashCausalLMBatch(Batch):
)["input_ids"] )["input_ids"]
position_ids = [] position_ids = []
past_present_indices = [] cu_seqlen_prefill = [0]
start_seq = [] needed_blocks_slots = []
end_seq = [] start_slots = []
start_seq_prefill = [] slot_indices = []
end_seq_prefill = []
max_seqlen = 0
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
@ -126,7 +222,10 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0 cumulative_max_length = 0
prefill_out_cumulative_length = 0 prefill_out_cumulative_length = 0
blocks = 0
max_seqlen = 0
max_length = 0 max_length = 0
max_blocks = 0
# Parse batch # Parse batch
for i, (r, tokenized_input) in enumerate( for i, (r, tokenized_input) in enumerate(
@ -138,7 +237,6 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :] tokenized_input = tokenized_input[-r.truncate :]
input_length = len(tokenized_input) input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length) input_lengths.append(input_length)
prefix_offsets.append(input_length - 5) prefix_offsets.append(input_length - 5)
@ -151,10 +249,7 @@ class FlashCausalLMBatch(Batch):
position_ids.append(request_position_ids) position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs # Add cumulative lengths of all previous inputs
start_seq_prefill.append(cumulative_length) cu_seqlen_prefill.append(cumulative_length + input_length)
end_seq_prefill.append(cumulative_length + input_length)
start_seq.append(cumulative_max_length)
end_seq.append(cumulative_max_length + input_length)
next_token_chooser_parameters.append(r.parameters) next_token_chooser_parameters.append(r.parameters)
@ -164,6 +259,21 @@ class FlashCausalLMBatch(Batch):
max_new_tokens = stopping_criteria.max_new_tokens max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
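# Worked arithmetic for the block accounting above (assumed numbers): a request with
# input_length == 10 and max_new_tokens == 20 reserves
#   total_tokens  == 10 + 20 - 1 == 29   (one removed, per the comment above)
#   needed_blocks == math.ceil(29 / 16) == 2
# and its prefill fills slot_indices cumulative_max_length .. cumulative_max_length + 9;
# for a batch holding only this request, to_pb() reports max_tokens == 2 * BLOCK_SIZE == 32.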
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@ -184,22 +294,17 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1 prefill_out_cumulative_length += 1
request_past_present_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
past_present_indices.append(request_past_present_indices)
# Update # Update
# Remove one as the first token does not have a past
cumulative_length += input_length cumulative_length += input_length
cumulative_max_length += input_length + max_new_tokens - 1 cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens) max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device next_token_chooser_parameters, dtype, device
) )
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded all_input_ids_tensor # Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros( all_input_ids_tensor = np.zeros(
@ -212,41 +317,32 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = torch.tensor( all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device all_input_ids_tensor, dtype=torch.int64, device=device
) )
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
if len(pb.requests) > 1: if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64) input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids) position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
start_seq_prefill = torch.tensor(
start_seq_prefill, device=device, dtype=torch.int32
)
end_seq_prefill = torch.tensor(
end_seq_prefill, device=device, dtype=torch.int32
)
else: else:
input_ids = all_input_ids[0] input_ids = all_input_ids[0]
position_ids = position_ids[0] position_ids = position_ids[0]
slot_indices = slot_indices[0]
past_present_indices = past_present_indices[0] cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
start_seq_prefill = start_seq )
end_seq_prefill = end_seq
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) input_lengths_tensor = torch.tensor(
past_present_indices = torch.tensor( input_lengths, dtype=torch.int32, device=device
past_present_indices, device=device, dtype=torch.int64
) )
if all_prefill_logprobs: if all_prefill_logprobs:
prefill_head_indices = None prefill_head_indices = None
prefill_next_token_indices = end_seq_prefill - 1 prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs: elif no_prefill_logprobs:
prefill_head_indices = end_seq_prefill - 1 prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None prefill_next_token_indices = None
else: else:
prefill_head_indices = torch.tensor( prefill_head_indices = torch.tensor(
@ -262,26 +358,27 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
past_present_indices=past_present_indices, cu_seqlen_prefill=cu_seqlen_prefill,
start_seq=start_seq, start_slots=start_slots,
end_seq=end_seq, slot_indices=slot_indices,
start_seq_prefill=start_seq_prefill, needed_blocks_slots=needed_blocks_slots,
end_seq_prefill=end_seq_prefill, block_tables=None,
start_seq_q=None, block_tables_tensor=None,
end_seq_q=None, slots=None,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices, prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices, prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens, prefill_cu_outlens=prefill_cu_outlens,
past_key_values=None,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length, blocks=blocks,
max_blocks=max_blocks,
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
@ -294,28 +391,24 @@ class FlashCausalLMBatch(Batch):
device = self.input_ids.device device = self.input_ids.device
# Cumulative length
cumulative_max_length = 0
# New values after filtering # New values after filtering
requests_idx_mapping = {} requests_idx_mapping = {}
# Used to index into tensors # Used to index into tensors
indices = [] indices = []
# past indices to keep # slots to keep after filtering
past_indices = torch.zeros( slot_filtering_indices = torch.zeros(
self.past_key_values.shape[0], dtype=torch.bool, device=device self.slots.shape[0], dtype=torch.bool, device=device
) )
# Create on CPU to only move to GPU once instead of at every copy # Create on CPU to only move to GPU once instead of at every copy
start_seq = torch.empty(len(request_ids), dtype=torch.int32) slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
start_seq_q = self.start_seq_q[: len(request_ids)]
end_seq_q = self.end_seq_q[: len(request_ids)]
max_seqlen = 0 max_seqlen = 0
requests = [] requests = []
start_slots = []
block_tables = []
all_input_ids = [] all_input_ids = []
input_lengths = [] input_lengths = []
@ -324,6 +417,11 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = [] stopping_criterias = []
blocks = 0
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
for i, request_id in enumerate(request_ids): for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id] idx = self.requests_idx_mapping[request_id]
indices.append(idx) indices.append(idx)
@ -348,28 +446,51 @@ class FlashCausalLMBatch(Batch):
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
) )
request_block_table = self.block_tables[idx]
blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
# Copy to tensor (CPU) # Copy to tensor (CPU)
start_seq[i] = cumulative_max_length slot_indices[i] = cumulative_max_length + request_input_length - 1
end_seq[i] = cumulative_max_length + request_input_length
# Set slice # Set slice
past_indices[ slot_filtering_indices[
self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1 self.start_slots[idx] : self.start_slots[idx]
+ request_input_length
+ remaining_tokens
- 1
] = True ] = True
cumulative_max_length += request_input_length + remaining_tokens - 1 cumulative_max_length += request_input_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table))
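# Worked example of the filtering above (assumed numbers): a kept request with
# start_slots[idx] == 0, input_length == 10 and 5 remaining tokens keeps slots 0..13
# (10 + 5 - 1 == 14 slots), and its new slot_indices entry is cumulative_max_length + 10 - 1,
# i.e. the slot holding its latest token.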
global CACHE_MANAGER
block_indices_to_free = []
# Iterate on all requests
for i, r in enumerate(self.requests):
# Filter requests that are not part of the new batch
if r.id not in requests_idx_mapping.keys():
block_indices_to_free.extend(self.block_tables[i])
# Free blocks
CACHE_MANAGER.free(block_indices_to_free)
# Needed to avoid dropping blocks when this batch goes out of scope
self.block_tables = None
# Index into tensors # Index into tensors
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
past_key_values = self.past_key_values[past_indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Move to GPU now that we have the whole tensor # Move to GPU now that we have the whole tensor
start_seq = start_seq.to(device) slot_indices = slot_indices.to(device)
end_seq = end_seq.to(device)
past_present_indices = end_seq - 1
return FlashCausalLMBatch( return FlashCausalLMBatch(
batch_id=self.batch_id, batch_id=self.batch_id,
@ -377,26 +498,27 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
past_present_indices=past_present_indices, cu_seqlen_prefill=None,
start_seq=start_seq, start_slots=start_slots,
end_seq=end_seq, slot_indices=slot_indices,
start_seq_prefill=None, needed_blocks_slots=None,
end_seq_prefill=None, block_tables=block_tables,
start_seq_q=start_seq_q, block_tables_tensor=block_tables_tensor,
end_seq_q=end_seq_q, slots=slots,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length, blocks=blocks,
max_blocks=max_blocks,
) )
@classmethod @classmethod
@ -406,22 +528,46 @@ class FlashCausalLMBatch(Batch):
requests = [] requests = []
requests_idx_mapping = {} requests_idx_mapping = {}
total_batch_size = sum([len(b) for b in batches]) blocks = 0
total_batch_size = 0
dtype = batches[0].past_key_values.dtype total_slots = 0
device = batches[0].input_ids.device max_blocks = 0
max_length = 0
max_seqlen = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
)
),
)
input_ids = batches[0].input_ids.new_empty(total_batch_size) input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size)
start_seq = batches[0].start_seq.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots)
end_seq = batches[0].end_seq.new_empty(total_batch_size) slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
start_seq_q = torch.arange( input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
0, total_batch_size, device=device, dtype=torch.int32 total_batch_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
) )
end_seq_q = start_seq_q + 1
max_seqlen = 0
past_key_values = []
start_slots = []
block_tables = []
all_input_ids = [] all_input_ids = []
input_lengths = [] input_lengths = []
@ -433,8 +579,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length # Cumulative length
cumulative_batch_size = 0 cumulative_batch_size = 0
max_tokens = 0 cumulative_slots = 0
max_length = 0
for i, batch in enumerate(batches): for i, batch in enumerate(batches):
requests.extend(batch.requests) requests.extend(batch.requests)
@ -448,16 +593,27 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch) end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU) # Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
slots[slots_start_index:slots_end_index] = batch.slots
start_seq[start_index:end_index] = batch.start_seq + max_tokens all_input_ids_tensor[
end_seq[start_index:end_index] = batch.end_seq + max_tokens start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
max_seqlen = max(max_seqlen, batch.max_seqlen) block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks]
start_slots.append(batch.start_slots + cumulative_slots)
block_tables.extend(batch.block_tables)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
@ -466,73 +622,58 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
past_key_values.append(batch.past_key_values)
# Update # Update
cumulative_batch_size += len(batch) cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens cumulative_slots += len(batch.slots)
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
batch.input_lengths, batch.stopping_criterias
)
),
)
past_key_values = torch.cat(past_key_values, dim=0) start_slots = torch.concat(start_slots)
past_present_indices = end_seq - 1
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
cumulative_batch_size = 0
for i, batch in enumerate(batches):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
cumulative_batch_size += len(batch)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
) )
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None
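# Worked example of the bookkeeping above (assumed numbers): merging a batch of 2 sequences
# holding 48 slots with a batch of 1 sequence holding 32 slots gives total_batch_size == 3
# and total_slots == 80; the second batch's slot_indices and start_slots are shifted by
# cumulative_slots == 48 so they still point at their own rows of the merged slots tensor,
# and its block tables are re-padded to the new max_blocks.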
return FlashCausalLMBatch( return FlashCausalLMBatch(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
past_present_indices=past_present_indices, cu_seqlen_prefill=None,
start_seq=start_seq, start_slots=start_slots,
end_seq=end_seq, slot_indices=slot_indices,
start_seq_prefill=None, needed_blocks_slots=None,
end_seq_prefill=None, block_tables=block_tables,
start_seq_q=start_seq_q, block_tables_tensor=block_tables_tensor,
end_seq_q=end_seq_q, slots=slots,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_tokens=max_tokens, blocks=blocks,
max_blocks=max_blocks,
) )
def __del__(self):
if self.block_tables is not None and self.block_tables:
global CACHE_MANAGER
# Free blocks
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
def __len__(self): def __len__(self):
return len(self.requests) return len(self.requests)
@ -540,32 +681,19 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model): class FlashCausalLM(Model):
def __init__( def __init__(
self, self,
model_cls: Type[PreTrainedModel], model: torch.nn.Module,
model_id: str, tokenizer: PreTrainedTokenizerBase,
revision: Optional[str] = None, num_layers: int,
quantize: Optional[str] = None, num_kv_heads: int,
trust_remote_code: bool = False, head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
): ):
if torch.cuda.is_available(): self.num_layers = num_layers
device = torch.device("cuda") self.num_kv_heads = num_kv_heads
dtype = torch.float16 self.head_size = head_size
else:
raise NotImplementedError("FlashCausalLM is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
).to(device)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model, model=model,
@ -573,12 +701,38 @@ class FlashCausalLM(Model):
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@property @property
def batch_type(self) -> Type[FlashCausalLMBatch]: def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch return FlashCausalLMBatch
def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
global CACHE_MANAGER
torch.cuda.empty_cache()
try:
CACHE_MANAGER = CacheManager(
# Adds some wiggle room
math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
_, batch = self.generate_token(batch)
except Exception as e:
logger.exception(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
f"prefill tokens. "
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
)
raise e
del batch
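# Illustrative arithmetic for the warmup above; the CLI value is assumed. With
# BLOCK_SIZE == 16 and --max-batch-total-tokens 16000, the cache is sized to
# math.ceil(16000 / 16) + 10 == 1010 blocks (16160 slots) per layer, and one prefill of
# the warmup batch is run so out-of-memory errors surface here instead of at runtime.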
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode( return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
@ -588,28 +742,25 @@ class FlashCausalLM(Model):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
start_seq: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor],
end_seq: torch.Tensor, block_tables: torch.Tensor,
start_seq_q: Optional[torch.Tensor], slots: torch.Tensor,
end_seq_q: Optional[torch.Tensor], input_lengths: torch.Tensor,
max_s: int, max_s: int,
past_present_indices: torch.Tensor,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
global CACHE_MANAGER
# Model Forward # Model Forward
return self.model.forward( return self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
start_seq=start_seq, cu_seqlen_prefill=cu_seqlen_prefill,
end_seq=end_seq, kv_cache=CACHE_MANAGER.kv_cache,
start_seq_q=start_seq_q, block_tables=block_tables,
end_seq_q=end_seq_q, slots=slots,
input_lengths=input_lengths,
max_s=max_s, max_s=max_s,
past_present_indices=past_present_indices,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
) )
@ -617,31 +768,21 @@ class FlashCausalLM(Model):
def generate_token( def generate_token(
self, batch: FlashCausalLMBatch self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
if prefill: if batch.needed_blocks_slots:
# Ask to pre-allocate kv to its max size # Allocate blocks to this batch
# == Sum over batch size (number of tokens + max_new_tokens) - batch size CACHE_MANAGER.allocate(batch)
pre_allocate_past_size = batch.max_tokens
start_seq = batch.start_seq_prefill
end_seq = batch.end_seq_prefill
else:
pre_allocate_past_size = None
start_seq = batch.start_seq
end_seq = batch.end_seq
out, present = self.forward( out = self.forward(
batch.input_ids, batch.input_ids,
batch.position_ids, batch.position_ids,
start_seq, batch.cu_seqlen_prefill,
end_seq, batch.block_tables_tensor,
batch.start_seq_q, batch.slots[batch.slot_indices],
batch.end_seq_q, batch.input_lengths_tensor,
batch.max_seqlen, batch.max_seqlen,
batch.past_present_indices,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices, batch.prefill_head_indices,
) )
@ -662,15 +803,10 @@ class FlashCausalLM(Model):
# When batch == 1, we will just use the batch.input_ids values directly # When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.start_seq_q and batch.end_seq_q for decode
batch.start_seq_q = torch.arange(
0, len(batch), device=self.device, dtype=torch.int32
)
batch.end_seq_q = batch.start_seq_q + 1
next_position_ids = batch.position_ids.new_empty(len(batch)) next_position_ids = batch.position_ids.new_empty(len(batch))
# We do not need start_seq_prefill and end_seq_prefill anymore batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
batch.start_seq_prefill = None # We do not need cu_seqlen_prefill anymore
batch.end_seq_prefill = None batch.cu_seqlen_prefill = None
else: else:
prefill_logprobs = None prefill_logprobs = None
next_position_ids = batch.position_ids next_position_ids = batch.position_ids
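# Worked example of the prefill branch above (assumed lengths): for two prompts of lengths
# 3 and 2, cu_seqlen_prefill == [0, 3, 5], so slot_indices (one entry per prompt token,
# length 5) is reduced to its entries at positions [2, 4], the last slot of each sequence;
# it and input_lengths_tensor are then advanced by one at every decode step.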
@ -731,8 +867,8 @@ class FlashCausalLM(Model):
# Set values in batch # Set values in batch
batch.input_ids = next_input_ids batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1 batch.position_ids = next_position_ids + 1
batch.past_present_indices = batch.end_seq batch.input_lengths_tensor += 1
batch.end_seq = batch.end_seq + 1 batch.slot_indices += 1
if prefill and prefill_logprobs: if prefill and prefill_logprobs:
# Get prefill logprobs # Get prefill logprobs
@ -755,7 +891,6 @@ class FlashCausalLM(Model):
batch.read_offsets, batch.read_offsets,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.all_input_ids_tensor,
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
next_token_ids, next_token_ids,
@ -770,7 +905,6 @@ class FlashCausalLM(Model):
read_offset, read_offset,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
all_input_ids_tensor,
do_sample, do_sample,
seed, seed,
next_token_id, next_token_id,
@ -845,19 +979,20 @@ class FlashCausalLM(Model):
generations.append(generation) generations.append(generation)
new_input_length = input_length + 1
# Update values # Update values
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = input_length + 1
batch.prefix_offsets[i] = prefix_offset batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
if stopped:
del batch
# No need to return a batch if we know that all requests stopped
return generations, None
batch.prefill_cu_outlens = None batch.prefill_cu_outlens = None
batch.prefill_head_indices = None batch.prefill_head_indices = None
batch.prefill_next_token_indices = None batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1 batch.max_seqlen = batch.max_seqlen + 1
batch.past_key_values = present
# No need to return a batch if we know that all requests stopped return generations, batch
return generations, batch if not stopped else None

View File

@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashLlama is only available on GPU") raise NotImplementedError("FlashLlama is only available on GPU")
@ -64,10 +65,12 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config, weights) model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashLlama, self).__init__(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,

View File

@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashNeoX is only available on GPU") raise NotImplementedError("FlashNeoX is only available on GPU")
@ -55,10 +56,12 @@ class FlashNeoXSharded(FlashCausalLM):
model = FlashGPTNeoXForCausalLM(config, weights) model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashNeoXSharded, self).__init__(
model=model.to(device), model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,

View File

@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashRW is only available on GPU") raise NotImplementedError("FlashRW is only available on GPU")
@ -55,10 +56,12 @@ class FlashRWSharded(FlashCausalLM):
model = FlashRWForCausalLM(config, weights) model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashRWSharded, self).__init__(
model=model.to(device), model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,

View File

@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU") raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@ -52,18 +53,22 @@ class FlashSantacoderSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group, filenames,
aliases = {"transformer.wte.weight": ["lm_head.weight"]} device=device,
dtype=dtype,
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
) )
model = FlashSantacoderForCausalLM(config, weights) model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashSantacoderSharded, self).__init__(
super(FlashCausalLM, self).__init__(
model=model.to(device), model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,

View File

@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -22,6 +22,9 @@ class Model(ABC):
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
): ):
if torch.cuda.is_available():
torch.cuda.set_per_process_memory_fraction(1.0)
self.model = model.eval() self.model = model.eval()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids) self.all_special_ids = set(tokenizer.all_special_ids)
@ -55,6 +58,9 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int):
self.generate_token(batch)
def decode_token( def decode_token(
self, self,
all_input_ids: List[int], all_input_ids: List[int],

View File

@ -0,0 +1,98 @@
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class MPTCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class MPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("MPTSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
config.quantize = quantize
model = MPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return MPTCausalLMBatch

View File

@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -12,11 +12,12 @@ class RW(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.append(1) read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
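# Worked example of the corrected budget above (assumed sizes): for a batch of 4 inputs
# with max_input_length == 100 and max_decode_tokens == 20, the old expression reserved
# 4 * 100 + 20 == 420 tokens while the fixed one reserves 4 * (100 + 20) == 480, counting
# the decode tokens for every sequence in the batch.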
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -53,6 +53,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
self.model.warmup(batch, request.max_total_tokens)
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context): async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb( batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device request.batch, self.model.tokenizer, self.model.dtype, self.model.device
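The Warmup handler added above rebuilds a batch from the request and calls self.model.warmup(batch, request.max_total_tokens) so the model can size its memory for the largest expected batch before serving traffic. A hedged sketch of a client-side call follows; the stub and request names (TextGenerationServiceStub, WarmupRequest) assume standard grpcio code generation, the import path is an assumption, and only the batch and max_total_tokens fields are confirmed by the handler itself:

import grpc
# Assumed import path; the servicer above refers to the generated modules as
# generate_pb2 and generate_pb2_grpc.
from text_generation_server.pb import generate_pb2, generate_pb2_grpc

async def warmup_shard(unix_url: str, batch_pb, max_total_tokens: int) -> None:
    # Open a channel to one shard's unix socket and issue a single Warmup call.
    async with grpc.aio.insecure_channel(unix_url) as channel:
        stub = generate_pb2_grpc.TextGenerationServiceStub(channel)
        await stub.Warmup(
            generate_pb2.WarmupRequest(batch=batch_pb, max_total_tokens=max_total_tokens)
        )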
@ -99,6 +106,7 @@ def serve(
revision: Optional[str], revision: Optional[str],
sharded: bool, sharded: bool,
quantize: Optional[str], quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
uds_path: Path, uds_path: Path,
): ):
@ -107,6 +115,7 @@ def serve(
revision: Optional[str], revision: Optional[str],
sharded: bool = False, sharded: bool = False,
quantize: Optional[str] = None, quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
unix_socket_template = "unix://{}-{}" unix_socket_template = "unix://{}-{}"
@ -121,7 +130,9 @@ def serve(
server_urls = [local_url] server_urls = [local_url]
try: try:
model = get_model(model_id, revision, sharded, quantize, trust_remote_code) model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
)
except Exception: except Exception:
logger.exception("Error when initializing model") logger.exception("Error when initializing model")
raise raise
@ -152,4 +163,6 @@ def serve(
logger.info("Signal received. Shutting down") logger.info("Signal received. Shutting down")
await server.stop(0) await server.stop(0)
asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code)) asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
)
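serve now threads the dtype argument through as a plain string (the CLI value) and leaves it to get_model to turn it into a torch.dtype; that conversion is not part of this diff. A purely illustrative sketch of such a mapping, with a hypothetical helper name:

from typing import Optional
import torch

def parse_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
    # Hypothetical helper: map the CLI string to a torch dtype and leave None
    # untouched so each model constructor can apply its own default.
    if dtype is None:
        return None
    mapping = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    if dtype not in mapping:
        raise RuntimeError(f"Unknown dtype {dtype}")
    return mapping[dtype]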

View File

@ -32,7 +32,19 @@ def load_layer_norm(cls, prefix, weights, eps):
return ln return ln
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = nn.Parameter(weight)
ln.bias = None
return ln
torch.nn.LayerNorm.load = load_layer_norm torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
class FastLinear(nn.Module): class FastLinear(nn.Module):
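load_layer_norm_no_bias mirrors load_layer_norm but attaches only the scale tensor and sets bias to None, which suits checkpoints whose layer norms ship without a bias. The snippet below is a usage sketch, not repository code: it assumes the module containing the patch above has been imported (so nn.LayerNorm.load_no_bias exists), that init_empty_weights comes from accelerate as in the surrounding file, and it substitutes a tiny stand-in for the Weights class shown later in this diff; the tensor name is hypothetical.

import torch
from torch import nn

class _DummyWeights:
    # Minimal stand-in for Weights: only get_tensor(name) is needed here.
    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name):
        return self.tensors[name]

weights = _DummyWeights({"transformer.ln_f.weight": torch.ones(1024)})

# The monkey-patched classmethod builds the LayerNorm, swaps in the real scale
# and leaves the bias as None.
ln = nn.LayerNorm.load_no_bias(prefix="transformer.ln_f", weights=weights, eps=1e-5)
assert ln.bias is None and tuple(ln.weight.shape) == (1024,)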
@ -357,7 +369,7 @@ try:
def __init__(self, inv_freq): def __init__(self, inv_freq):
super().__init__() super().__init__()
self.register_buffer("inv_freq", inv_freq) self.inv_freq = inv_freq
self._seq_len_cached = 0 self._seq_len_cached = 0
self._cos_cached = None self._cos_cached = None
self._sin_cached = None self._sin_cached = None
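The rotary-embedding change above stores inv_freq as a plain attribute instead of a registered buffer, which keeps it out of state_dict() and out of checkpoint loading; a plausible motivation is that the freshly computed tensor should not be overridden or cast by whatever a checkpoint happens to contain, though the diff itself does not state the reason. A small self-contained illustration of the difference, using toy modules rather than the repository's class:

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self, inv_freq):
        super().__init__()
        self.register_buffer("inv_freq", inv_freq)  # tracked by state_dict

class PlainAttribute(nn.Module):
    def __init__(self, inv_freq):
        super().__init__()
        self.inv_freq = inv_freq  # invisible to state_dict / load_state_dict

inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
print(list(WithBuffer(inv_freq).state_dict()))      # ['inv_freq']
print(list(PlainAttribute(inv_freq).state_dict()))  # []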

View File

@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep
# Keep at least min_tokens_to_keep sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(
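The warper change above drops the min_tokens_to_keep > 1 guard and always protects the last min_tokens_to_keep positions of the ascending sort, i.e. the highest-probability tokens; top_p_opposite appears to hold 1 - top_p, so every token whose cumulative mass is still below that threshold gets filtered. A standalone sketch of the same nucleus-filtering logic, written independently of the class above:

import torch

def top_p_removal_mask(scores: torch.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
    # Sort ascending, accumulate probability mass, and mark the low-probability
    # tail (cumulative mass <= 1 - top_p) for removal.
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative <= (1 - top_p)
    # Keep at least min_tokens_to_keep (the last, highest-probability positions).
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
    # Scatter back to the original token order.
    return sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_p_removal_mask(scores, top_p=0.9))  # tensor([[False, False, False,  True]])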

View File

@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser:
self.seeds = seeds self.seeds = seeds
self.do_sample = do_sample self.do_sample = do_sample
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor is not None: if self.watermark_processor is not None:

View File

@ -3,8 +3,16 @@ from typing import List, Dict, Optional
from safetensors import safe_open from safetensors import safe_open
import torch import torch
class Weights: class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): def __init__(
self,
filenames: List[Path],
device,
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
):
routing = {} routing = {}
for filename in filenames: for filename in filenames:
with safe_open(filename, framework="pytorch") as f: with safe_open(filename, framework="pytorch") as f: