Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)

Commit 620ed7d8aa: Merge branch 'gptq-cuda-kernels' of https://github.com/fxmarty/text-generation-inference into gptq-cuda-kernels

Cargo.lock (generated, 505 changes): diff suppressed because it is too large.
Cargo.toml

@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "0.8.2"
+version = "0.9.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
Dockerfile (18 changes)

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.70 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
     /opt/conda/bin/conda clean -ya
 
-
 # Build Flash Attention CUDA kernels
 FROM kernel-builder as flash-att-builder
 
@@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build
 
+# Build vllm CUDA kernels
+FROM kernel-builder as vllm-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-vllm Makefile
+
+# Build specific version of vllm
+RUN make build-vllm
+
 # Text Generation Inference base image
 FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
 
@@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
-# Copy build artifacts from transformers builder
+# Copy build artifacts from custom kernels builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
 
+# Copy builds artifacts from vllm builder
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
 
README.md

@@ -43,8 +43,8 @@ to power LLMs api-inference widgets.
 - Tensor Parallelism for faster inference on multiple GPUs
 - Token streaming using Server-Sent Events (SSE)
 - [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
-- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
+- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@@ -84,7 +84,7 @@ model=bigscience/bloom-560m
 num_shard=2
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.8 --model-id $model --num-shard $num_shard
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9 --model-id $model --num-shard $num_shard
 ```
 **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
 
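For context, a client can exercise the server started by the `docker run` line above through its HTTP API. The sketch below is illustrative only: it assumes the container is reachable on localhost:8080 and uses the `/generate` route and `max_new_tokens` parameter that appear elsewhere in this diff; the OpenAPI document remains the authoritative schema.

```rust
// Illustrative sketch, not part of the diff. Assumes reqwest = { version = "0.11",
// features = ["blocking", "json"] } and serde_json as dependencies, and a server
// started as in the README listening on localhost:8080.
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::blocking::Client::new();
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": { "max_new_tokens": 17 }
    });
    // POST /generate returns the generated text as JSON.
    let response: serde_json::Value = client
        .post("http://localhost:8080/generate")
        .json(&body)
        .send()?
        .error_for_status()?
        .json()?;
    println!("{response}");
    Ok(())
}
```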
Deleted file (Azure ML endpoint documentation):

@@ -1,15 +0,0 @@
-# Azure ML endpoint
-
-## Create all resources
-
-```shell
-az ml model create -f model.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-```
-
-## Update deployment
-
-```shell
-az ml online-deployment update -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
-```
Deleted file (Azure ML deployment.yaml):

@@ -1,38 +0,0 @@
-$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
-name: bloom-deployment
-endpoint_name: bloom-inference
-model: azureml:bloom-safetensors:1
-model_mount_path: /var/azureml-model
-environment_variables:
-  WEIGHTS_CACHE_OVERRIDE: /var/azureml-model/bloom-safetensors
-  MODEL_ID: bigscience/bloom
-  NUM_SHARD: 8
-environment:
-  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2.0
-  inference_config:
-    liveness_route:
-      port: 80
-      path: /health
-    readiness_route:
-      port: 80
-      path: /health
-    scoring_route:
-      port: 80
-      path: /generate
-instance_type: Standard_ND96amsr_A100_v4
-request_settings:
-  request_timeout_ms: 90000
-  max_concurrent_requests_per_instance: 256
-liveness_probe:
-  initial_delay: 600
-  timeout: 90
-  period: 120
-  success_threshold: 1
-  failure_threshold: 5
-readiness_probe:
-  initial_delay: 600
-  timeout: 90
-  period: 120
-  success_threshold: 1
-  failure_threshold: 5
-instance_count: 1
Deleted file (Azure ML endpoint.yaml):

@@ -1,3 +0,0 @@
-$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
-name: bloom-inference
-auth_mode: key
Deleted file (Azure ML model.yaml):

@@ -1,3 +0,0 @@
-$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
-name: bloom-safetensors
-path: /data/bloom-safetensors
benchmark/src/main.rs

@@ -14,36 +14,85 @@ use tracing_subscriber::EnvFilter;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
+    /// The name of the tokenizer (as in model_id on the huggingface hub, or local path).
     #[clap(short, long, env)]
     tokenizer_name: String,
 
+    /// The revision to use for the tokenizer if on the hub.
     #[clap(default_value = "main", long, env)]
     revision: String,
 
+    /// The various batch sizes to benchmark for, the idea is to get enough
+    /// batching to start seeing increased latency, this usually means you're
+    /// moving from memory bound (usual as BS=1) to compute bound, and this is
+    /// a sweet spot for the maximum batch size for the model under test
     #[clap(short, long)]
     batch_size: Option<Vec<u32>>,
 
+    /// This is the initial prompt sent to the text-generation-server length
+    /// in token. Longer prompt will slow down the benchmark. Usually the
+    /// latency grows somewhat linearly with this for the prefill step.
+    ///
+    /// Most importantly, the prefill step is usually not the one dominating
+    /// your runtime, so it's ok to keep it short.
     #[clap(default_value = "10", short, long, env)]
     sequence_length: u32,
 
+    /// This is how many tokens will be generated by the server and averaged out
+    /// to give the `decode` latency. This is the *critical* number you want to optimize for
+    /// LLM spend most of their time doing decoding.
+    ///
+    /// Decode latency is usually quite stable.
     #[clap(default_value = "8", short, long, env)]
     decode_length: u32,
 
+    ///How many runs should we average from
     #[clap(default_value = "10", short, long, env)]
     runs: usize,
 
+    /// Number of warmup cycles
     #[clap(default_value = "1", short, long, env)]
     warmups: usize,
-    #[clap(long, env)]
-    temperature: Option<f32>,
-    #[clap(long, env)]
-    top_k: Option<u32>,
-    #[clap(long, env)]
-    top_p: Option<f32>,
-    #[clap(long, env)]
-    typical_p: Option<f32>,
-    #[clap(long, env)]
-    repetition_penalty: Option<f32>,
-    #[clap(long, env)]
-    watermark: bool,
-    #[clap(long, env)]
-    do_sample: bool,
+
+    /// The location of the grpc socket. This benchmark tool bypasses the router
+    /// completely and directly talks to the gRPC processes
     #[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
     master_shard_uds_path: String,
 
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    temperature: Option<f32>,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    top_k: Option<u32>,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    top_p: Option<f32>,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    typical_p: Option<f32>,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    repetition_penalty: Option<f32>,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    watermark: bool,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    do_sample: bool,
 }
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
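As a hedged aside, the small standalone program below shows how the clap attributes used in the benchmark's `Args` above behave: `short, long` produces both flag forms, `env` lets the upper-cased field name act as an environment-variable fallback, and `default_value` applies when neither is given. The struct and values here are illustrative, not the benchmark's real arguments.

```rust
// Standalone sketch. Assumes clap = { version = "4", features = ["derive", "env"] }.
use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// `short, long` gives both `-s <N>` and `--sequence-length <N>`;
    /// `env` lets SEQUENCE_LENGTH provide the value; `default_value` applies otherwise.
    #[clap(default_value = "10", short, long, env)]
    sequence_length: u32,

    /// An optional value: None unless `--temperature <T>` (or TEMPERATURE) is set.
    #[clap(long, env)]
    temperature: Option<f32>,
}

fn main() {
    // e.g. `demo --sequence-length 512 --temperature 0.8`
    let args = DemoArgs::parse();
    println!("{args:?}");
}
```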
docs/openapi.json

@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "0.8.2"
+    "version": "0.9.0"
   },
   "paths": {
     "/": {
@@ -270,6 +270,35 @@
         }
       }
     },
+    "/health": {
+      "get": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Health check method",
+        "description": "Health check method",
+        "operationId": "health",
+        "responses": {
+          "200": {
+            "description": "Everything is working fine"
+          },
+          "503": {
+            "description": "Text generation inference is down",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "unhealthy",
+                  "error_type": "healthcheck"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/info": {
       "get": {
         "tags": [
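A hedged sketch of consuming the `/health` route documented above; the base URL is an assumption, and the 503 body shape comes from the example payload in the spec.

```rust
// Sketch only: poll the `/health` route documented above. Assumes the server runs
// on localhost:8080 and reqwest = { version = "0.11", features = ["blocking"] }.
fn server_is_healthy(base_url: &str) -> Result<bool, reqwest::Error> {
    let status = reqwest::blocking::get(format!("{base_url}/health"))?.status();
    // 200 means everything is working fine; 503 carries an ErrorResponse body
    // such as {"error": "unhealthy", "error_type": "healthcheck"}.
    Ok(status.is_success())
}

fn main() -> Result<(), reqwest::Error> {
    println!("healthy: {}", server_is_healthy("http://localhost:8080")?);
    Ok(())
}
```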
integration-tests/models/__snapshots__/test_mpt/test_mpt.json (new file, 140 lines)

@@ -0,0 +1,140 @@
New snapshot for the MPT integration test: a single generation response with "finish_reason": "length", "generated_tokens": 17, the prefill tokens of the prompt "What is Deep Learning?" (ids 1276, 310, 18147, 20727, 32) with their logprobs, the 17 generated tokens with per-token ids, logprobs and "special": false, "seed": null, and "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural".
New 562-line snapshot (the load-test counterpart of the file above):

@@ -0,0 +1,562 @@
A JSON array of four generation responses for the same "What is Deep Learning?" prompt, each with "finish_reason": "length", "generated_tokens": 17, per-token prefill and generated logprobs that differ only slightly between the four entries, and the same "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural".
integration-tests/models/test_flash_neox.py

@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
     return flash_neox_handle.client
 
 
+@pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox(flash_neox, response_snapshot):
     response = await flash_neox.generate(
@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
     assert response == response_snapshot
 
 
+@pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
     responses = await generate_load(
integration-tests/models/test_mpt.py (new file, 48 lines)

@@ -0,0 +1,48 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def mpt_sharded_handle(launcher):
+    with launcher("mosaicml/mpt-7b", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def mpt_sharded(mpt_sharded_handle):
+    await mpt_sharded_handle.health(300)
+    return mpt_sharded_handle.client
+
+
+@pytest.mark.asyncio
+async def test_mpt(mpt_sharded, response_snapshot):
+    response = await mpt_sharded.generate(
+        "What is Deep Learning?",
+        max_new_tokens=17,
+        decoder_input_details=True,
+    )
+
+    assert response.details.generated_tokens == 17
+    assert (
+        response.generated_text
+        == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
+    responses = await generate_load(
+        mpt_sharded,
+        "What is Deep Learning?",
+        max_new_tokens=17,
+        n=4,
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert (
+        responses[0].generated_text
+        == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
+    )
+
+    assert responses == response_snapshot
launcher/src/main.rs

@@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read};
 use std::path::Path;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::mpsc::TryRecvError;
-use std::sync::Arc;
-use std::sync::{mpsc, Mutex};
+use std::sync::{mpsc, Arc};
 use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
@@ -40,6 +39,25 @@ impl std::fmt::Display for Quantization {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Dtype {
+    Float16,
+    BFloat16,
+}
+
+impl std::fmt::Display for Dtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // To keep in track with `server`.
+        match self {
+            Dtype::Float16 => {
+                write!(f, "float16")
+            }
+            Dtype::BFloat16 => {
+                write!(f, "bfloat16")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -76,6 +94,10 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
+    #[clap(long, env, value_enum)]
+    dtype: Option<Dtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
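As an illustration (not launcher code), the snippet below shows the round trip the new `Dtype` option relies on: clap's `ValueEnum` derive parses the CLI value into the enum, and the `Display` impl converts it back to the lowercase string that the launcher pushes onto the shard's argv as `--dtype <value>`. The demo struct and invocation are assumptions for the sketch.

```rust
// Minimal sketch assuming clap = { version = "4", features = ["derive", "env"] }.
// `Dtype` mirrors the enum introduced in the launcher diff above.
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Dtype::Float16 => write!(f, "float16"),
            Dtype::BFloat16 => write!(f, "bfloat16"),
        }
    }
}

#[derive(Parser, Debug)]
struct DemoArgs {
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,
}

fn main() {
    // e.g. `demo --dtype float16`
    let args = DemoArgs::parse();
    if let Some(dtype) = args.dtype {
        // The launcher forwards this string as `--dtype <value>` to the Python shard.
        println!("--dtype {dtype}");
    }
}
```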
@@ -106,7 +128,7 @@ struct Args {
     /// for users. The larger this value, the longer prompt users can send which
     /// can impact the overall memory required to handle the load.
     /// Please note that some models have a finite range of sequence they can handle.
-    #[clap(default_value = "1000", long, env)]
+    #[clap(default_value = "1024", long, env)]
     max_input_length: usize,
 
     /// This is the most important value to set as it defines the "memory budget"
@@ -117,15 +139,9 @@ struct Args {
     /// `1511` max_new_tokens.
     /// The larger this value, the larger amount each request will be in your RAM
     /// and the less effective batching can be.
-    #[clap(default_value = "1512", long, env)]
+    #[clap(default_value = "2048", long, env)]
     max_total_tokens: usize,
 
-    /// The maximum allowed batch size during dynamic batching.
-    /// Using `max_batch_total_tokens` should be favored in general
-    /// as it's a finer way to control RAM usage.
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
-
     /// This represents the ratio of waiting queries vs running queries where
     /// you want to start considering pausing the running queries to include the waiting
     /// ones into the same batch.
@@ -139,6 +155,12 @@ struct Args {
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
 
+    /// Limits the number of tokens for the prefill operation.
+    /// Since this operation take the most memory and is compute bound, it is interesting
+    /// to limit the number of requests that can be sent.
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+
     /// **IMPORTANT** This is one critical control to allow maximum usage
     /// of the available hardware.
     ///
@@ -151,19 +173,12 @@ struct Args {
     /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
     /// or a single query of `1000` tokens.
     ///
-    /// So you don't have to control that finely
-    /// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
-    /// want maximum flexibility. However, for your users if they are asking for the full amount of
-    /// total tokens, they are likely to wait for a very long time to get a spot
-    /// in the batch (since they are going to be alone) so setting `max_batch_size`
-    /// and `max_total_tokens` can still be useful to prevent those long waiting times.
-    ///
     /// Overall this number should be the largest possible amount that fits the
     /// remaining memory (after the model is loaded). Since the actual memory overhead
     /// depends on other parameters like if you're using quantization, flash attention
     /// or the model implementation, text-generation-inference cannot infer this number
     /// automatically.
-    #[clap(default_value = "32000", long, env)]
+    #[clap(default_value = "16000", long, env)]
     max_batch_total_tokens: u32,
 
     /// This setting defines how many tokens can be passed before forcing the waiting
@@ -185,9 +200,9 @@ struct Args {
     /// for end users.
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
-    #[clap(default_value = "3000", long, short, env)]
+
     /// The port to listen on.
+    #[clap(default_value = "3000", long, short, env)]
     port: u16,
 
     /// The name of the socket for gRPC communication between the webserver
@@ -262,7 +277,7 @@ struct Args {
 #[derive(Debug)]
 enum ShardStatus {
     Ready,
-    Failed((usize, String)),
+    Failed((usize, Option<String>)),
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -270,6 +285,7 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
+    dtype: Option<Dtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -283,7 +299,7 @@ fn shard_manager(
     watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
-    shutdown: Arc<Mutex<bool>>,
+    shutdown: Arc<AtomicBool>,
     _shutdown_sender: mpsc::Sender<()>,
 ) {
     // Get UDS path
@@ -319,6 +335,11 @@ fn shard_manager(
         shard_argv.push(quantize.to_string())
     }
 
+    if let Some(dtype) = dtype {
+        shard_argv.push("--dtype".to_string());
+        shard_argv.push(dtype.to_string())
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_argv.push("--revision".to_string());
@@ -334,6 +355,12 @@ fn shard_manager(
     // Copy current process env
     let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
 
+    // Use cuda allocator. It leads to less memory fragmentation
+    env.push((
+        "PYTORCH_CUDA_ALLOC_CONF".into(),
+        "backend:cudaMallocAsync".into(),
+    ));
+
     // Torch Distributed Env vars
     env.push(("RANK".into(), rank.to_string().into()));
     env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
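For illustration only: the env push above launches each Python shard with `PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync` set. With std's process API (rather than the subprocess crate the launcher uses), that roughly looks like the hypothetical sketch below.

```rust
// Sketch using std::process instead of the launcher's subprocess crate: spawn a
// child with the allocator setting the diff above adds for every shard.
use std::process::Command;

fn main() -> std::io::Result<()> {
    let status = Command::new("python")
        .args(["-c", "import os; print(os.environ['PYTORCH_CUDA_ALLOC_CONF'])"])
        .env("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
        .status()?;
    println!("child exited with {status}");
    Ok(())
}
```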
@@ -409,20 +436,20 @@ fn shard_manager(
             }
         }
         status_sender
-            .send(ShardStatus::Failed((rank, err.to_string())))
+            .send(ShardStatus::Failed((rank, Some(err.to_string()))))
             .unwrap();
         return;
         }
     };
 
     // Redirect STDOUT to the console
-    let shard_stdout = p.stdout.take().unwrap();
+    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
+    let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
+
     thread::spawn(move || {
         // Enter shard-manager tracing span
-        let stdout = BufReader::new(shard_stdout);
         let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
-        for line in stdout.lines() {
+        for line in shard_stdout_reader.lines() {
             // Parse loguru logs
             if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
                 log.trace();
@@ -436,8 +463,22 @@ fn shard_manager(
     loop {
         // Process exited
         if let Some(exit_status) = p.poll() {
+            // We read stderr in another thread as it seems that `read_to_string` can block
+            // indefinitely in some cases
+            let (err_sender, err_receiver) = mpsc::channel();
+            thread::spawn(move || {
                 let mut err = String::new();
-            p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
+                shard_stderr_reader.read_to_string(&mut err).unwrap();
+                err_sender.send(err).unwrap_or(());
+            });
+
+            let err = err_receiver
+                .recv_timeout(Duration::from_millis(100))
+                .map_err(|err| {
+                    tracing::error!("Unable to read shard {rank} error from stderr");
+                    err
+                })
+                .ok();
+
             if let ExitStatus::Signaled(signal) = exit_status {
                 tracing::error!("Shard process was signaled to shutdown with signal {signal}");
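The pattern above, reading a child's stderr on a helper thread and waiting on a channel with a timeout so the parent never blocks, can be illustrated with a small self-contained sketch. It uses std::process rather than the launcher's Popen handle, so the surrounding details are assumptions.

```rust
// Self-contained sketch of the "read stderr with a timeout" pattern used above.
use std::io::Read;
use std::process::{Command, Stdio};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() -> std::io::Result<()> {
    let mut child = Command::new("sh")
        .args(["-c", "echo oops >&2; exit 1"])
        .stderr(Stdio::piped())
        .spawn()?;

    let mut stderr = child.stderr.take().expect("stderr was piped");
    let (err_sender, err_receiver) = mpsc::channel();

    // Reading happens on its own thread so a hung pipe cannot block the caller.
    thread::spawn(move || {
        let mut err = String::new();
        stderr.read_to_string(&mut err).unwrap_or(0);
        err_sender.send(err).unwrap_or(());
    });

    let status = child.wait()?;
    // Give the reader a short window, then move on without it if necessary.
    let err = err_receiver.recv_timeout(Duration::from_millis(100)).ok();
    println!("exit: {status}, stderr: {err:?}");
    Ok(())
}
```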
@@ -450,8 +491,8 @@ fn shard_manager(
         }
 
         // We received a shutdown signal
-        if *shutdown.lock().unwrap() {
-            p.terminate().unwrap();
+        if shutdown.load(Ordering::SeqCst) {
+            p.kill().unwrap();
             let _ = p.wait_timeout(Duration::from_secs(90));
             tracing::info!("Shard {rank} terminated");
             return;
@@ -470,14 +511,11 @@ fn shard_manager(
     }
 }
 
-fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
+fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
     tracing::info!("Shutting down shards");
     // Update shutdown value to true
     // This will be picked up by the shard manager
-    {
-        let mut shutdown = shutdown.lock().unwrap();
-        *shutdown = true;
-    }
+    shutdown.store(true, Ordering::SeqCst);
 
     // Wait for shards to shutdown
     // This will block till all shutdown_sender are dropped
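A hedged standalone sketch of the shared-flag pattern this diff moves to: an `Arc<AtomicBool>` can be set from one thread and polled lock-free from others, which is why it replaces the `Arc<Mutex<bool>>` used before. The worker loop and timings below are illustrative assumptions, not launcher code.

```rust
// Standalone sketch: one thread requests shutdown by storing `true`,
// worker threads poll the flag and exit, no Mutex required.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

fn main() {
    let shutdown = Arc::new(AtomicBool::new(false));

    let workers: Vec<_> = (0..2)
        .map(|rank| {
            let shutdown = shutdown.clone();
            thread::spawn(move || {
                while !shutdown.load(Ordering::SeqCst) {
                    thread::sleep(Duration::from_millis(10));
                }
                println!("worker {rank} observed shutdown");
            })
        })
        .collect();

    thread::sleep(Duration::from_millis(50));
    // Equivalent of `shutdown_shards`: flip the flag once, every worker sees it.
    shutdown.store(true, Ordering::SeqCst);
    for worker in workers {
        worker.join().unwrap();
    }
}
```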
@@ -719,7 +757,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
 fn spawn_shards(
     num_shard: usize,
     args: &Args,
-    shutdown: Arc<Mutex<bool>>,
+    shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
     status_receiver: &mpsc::Receiver<ShardStatus>,
@@ -749,6 +787,7 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
+        let dtype = args.dtype;
         let trust_remote_code = args.trust_remote_code;
         let master_port = args.master_port;
         let disable_custom_kernels = args.disable_custom_kernels;
@@ -759,6 +798,7 @@ fn spawn_shards(
                 model_id,
                 revision,
                 quantize,
+                dtype,
                 trust_remote_code,
                 uds_path,
                 rank,
@@ -793,7 +833,10 @@ fn spawn_shards(
                 sleep(Duration::from_millis(100));
             }
             Ok(ShardStatus::Failed((rank, err))) => {
-                tracing::error!("Shard {} failed to start:\n{}", rank, err);
+                tracing::error!("Shard {rank} failed to start");
+                if let Some(err) = err {
+                    tracing::error!("{err}");
+                }
                 shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardCannotStart);
             }
@@ -809,7 +852,7 @@ fn spawn_shards(
 
 fn spawn_webserver(
     args: Args,
-    shutdown: Arc<Mutex<bool>>,
+    shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
 ) -> Result<Popen, LauncherError> {
     // All shard started
@@ -827,6 +870,10 @@ fn spawn_webserver(
         args.max_input_length.to_string(),
         "--max-total-tokens".to_string(),
         args.max_total_tokens.to_string(),
+        "--max-batch-prefill-tokens".to_string(),
+        args.max_batch_prefill_tokens.to_string(),
+        "--max-batch-total-tokens".to_string(),
+        args.max_batch_total_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
@@ -839,15 +886,6 @@ fn spawn_webserver(
         args.model_id,
     ];
 
-    // Deprecate max_batch_size
-    if let Some(max_batch_size) = args.max_batch_size {
-        argv.push("--max-batch-size".to_string());
-        argv.push(max_batch_size.to_string())
-    } else {
-        argv.push("--max-batch-total-tokens".to_string());
-        argv.push(args.max_batch_total_tokens.to_string())
-    }
-
     // Model optional revision
     if let Some(ref revision) = args.revision {
         argv.push("--revision".to_string());
@@ -981,7 +1019,7 @@ fn main() -> Result<(), LauncherError> {
     download_convert_model(&args, running.clone())?;
 
     // Shared shutdown bool
-    let shutdown = Arc::new(Mutex::new(false));
+    let shutdown = Arc::new(AtomicBool::new(false));
     // Shared shutdown channel
     // When shutting down, the main thread will wait for all senders to be dropped
     let (shutdown_sender, shutdown_receiver) = mpsc::channel();
@@ -1006,14 +1044,21 @@ fn main() -> Result<(), LauncherError> {
         return Ok(());
     }
 
-    let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
+    let mut webserver =
+        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
+            shutdown_shards(shutdown.clone(), &shutdown_receiver);
+            err
+        })?;
 
     // Default exit code
     let mut exit_code = Ok(());
 
     while running.load(Ordering::SeqCst) {
         if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed:\n{err}");
+            tracing::error!("Shard {rank} crashed");
+            if let Some(err) = err {
+                tracing::error!("{err}");
+            }
             exit_code = Err(LauncherError::ShardFailed);
             break;
         };
proto/generate.proto

@@ -11,6 +11,8 @@ service TextGenerationService {
     rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
     /// Remove requests from a cached batch
     rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
+    /// Warmup the model and compute max cache size
+    rpc Warmup (WarmupRequest) returns (WarmupResponse);
     /// Prefill batch and decode first token
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
@@ -192,3 +194,13 @@ message DecodeResponse {
     /// Next batch (cached)
     optional CachedBatch batch = 2;
 }
+
+message WarmupRequest {
+    /// Batch to warmup on
+    Batch batch = 1;
+    /// Maximum number of tokens that the client will send
+    uint32 max_total_tokens = 2;
+}
+
+/// Empty response
+message WarmupResponse {}
@@ -22,11 +22,11 @@ text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] }
flume = "0.10.14"
futures = "0.3.26"
-metrics = "0.20.1"
-metrics-exporter-prometheus = { version = "0.11.0", features = [] }
+metrics = "0.21.0"
+metrics-exporter-prometheus = { version = "0.12.1", features = [] }
nohash-hasher = "0.2.0"
-opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
-opentelemetry-otlp = "0.11.0"
+opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
+opentelemetry-otlp = "0.12.0"
rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
@@ -36,7 +36,7 @@ tokenizers = "0.13.3"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = "0.1.37"
-tracing-opentelemetry = "0.18.0"
+tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
@@ -11,10 +11,10 @@ grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.11"
thiserror = "^1.0"
tokio = { version = "^1.25", features = ["sync"] }
-tonic = "^0.8"
+tonic = "^0.9"
tower = "^0.4"
tracing = "^0.1"

[build-dependencies]
-tonic-build = "0.8.4"
+tonic-build = "0.9.2"
prost-build = "0.11.6"
@@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
+use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

@@ -94,6 +95,63 @@ impl Client {
        Ok(filtered_batch.batch)
    }

+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum amount of tokens supported by the hardware
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+    ) -> Result<()> {
+        let mut n_tokens = 0;
+        let mut requests = Vec::new();
+
+        // Create requests
+        while n_tokens < max_prefill_tokens {
+            requests.push(Request {
+                id: 0,
+                // We truncate the input on the server side to be sure that it has the correct size
+                inputs: "_test ".to_string().repeat(max_input_length as usize),
+                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                // Set sampling parameters to also take these ops into account in the max memory
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 0.9,
+                    top_k: 10,
+                    top_p: 0.9,
+                    typical_p: 0.9,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.2,
+                    watermark: true,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 2,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+                prefill_logprobs: true,
+            });
+            n_tokens += max_input_length;
+        }
+
+        let batch = Batch {
+            id: 0,
+            size: requests.len() as u32,
+            requests,
+            max_tokens: 0,
+        };
+
+        let request = tonic::Request::new(WarmupRequest {
+            batch: Some(batch),
+            max_total_tokens,
+        })
+        .inject_context();
+        self.stub.warmup(request).await?.into_inner();
+        Ok(())
+    }
+
    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
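The new `warmup` call above builds a synthetic batch by repeatedly pushing a request of `max_input_length` tokens until the running count reaches `max_prefill_tokens`, then sends it once so the server can size its cache before serving traffic. A small sketch of how the warmup batch size falls out of those two limits (the function is illustrative; the sample values are the new router defaults shown later in this diff):

```rust
// Number of synthetic warmup requests built by the loop above: one request of
// `max_input_length` tokens is pushed until the running token count reaches
// `max_prefill_tokens`.
fn warmup_batch_size(max_input_length: u32, max_prefill_tokens: u32) -> u32 {
    let mut n_tokens = 0;
    let mut n_requests = 0;
    while n_tokens < max_prefill_tokens {
        n_requests += 1;
        n_tokens += max_input_length;
    }
    n_requests
}

fn main() {
    // With the new router defaults (1024 input tokens, 4096 prefill tokens)
    // the warmup batch contains 4 max-length requests.
    assert_eq!(warmup_batch_size(1024, 4096), 4);
    println!("warmup batch size: {}", warmup_batch_size(1024, 4096));
}
```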
@@ -87,6 +87,27 @@ impl ShardedClient {
        join_all(futures).await.pop().unwrap()
    }

+    /// Warmup on a max size batch
+    ///
+    /// Returns the maximum amount of tokens supported by the hardware
+    #[instrument(skip(self))]
+    pub async fn warmup(
+        &mut self,
+        max_input_length: u32,
+        max_prefill_tokens: u32,
+        max_total_tokens: u32,
+    ) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| {
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+            })
+            .collect();
+        // all shards return the same message
+        join_all(futures).await.pop().unwrap()
+    }
+
    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"

[dependencies]
-opentelemetry = "0.18.0"
-tonic = "^0.8"
+opentelemetry = "^0.19"
+tonic = "^0.9"
tracing = "^0.1"
-tracing-opentelemetry = "0.18.0"
+tracing-opentelemetry = "^0.19"
@@ -45,6 +45,7 @@ impl Infer {
        client: ShardedClient,
        validation: Validation,
        waiting_served_ratio: f32,
+        max_batch_prefill_tokens: u32,
        max_batch_total_tokens: u32,
        max_waiting_tokens: usize,
        max_concurrent_requests: usize,
@@ -61,6 +62,7 @@ impl Infer {
        tokio::spawn(batching_task(
            client,
            waiting_served_ratio,
+            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            queue.clone(),
@@ -240,9 +242,11 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
+#[allow(clippy::too_many_arguments)]
async fn batching_task(
    mut client: ShardedClient,
    waiting_served_ratio: f32,
+    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    queue: Queue,
@@ -257,8 +261,9 @@ async fn batching_task(
        // Get the next batch from the queue
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the queue
-        while let Some((mut entries, batch, span)) =
-            queue.next_batch(None, max_batch_total_tokens).await
+        while let Some((mut entries, batch, span)) = queue
+            .next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
+            .await
        {
            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                .instrument(span)
@@ -284,11 +289,12 @@ async fn batching_task(
                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
                };

-                let token_budget = max_batch_total_tokens - batch_max_tokens;
+                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

                // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) =
-                    queue.next_batch(min_size, token_budget).await
+                if let Some((mut new_entries, new_batch, span)) = queue
+                    .next_batch(min_size, max_batch_prefill_tokens, token_budget)
+                    .await
                {
                    // Tracking metrics
                    if min_size.is_some() {
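Two details of the batching loop above are worth spelling out: the queue is now asked for batches under a separate prefill budget, and the remaining total-token budget is computed with `saturating_sub`, so an already-full running batch yields a budget of zero instead of underflowing the `u32`. A minimal sketch of that budget arithmetic outside the router types (the sample values reuse the new 16000-token default; `batch_max_tokens` values are illustrative):

```rust
fn remaining_token_budget(max_batch_total_tokens: u32, batch_max_tokens: u32) -> u32 {
    // `saturating_sub` clamps at zero instead of panicking (debug builds) or
    // wrapping around (release builds) when the running batch already fills the budget.
    max_batch_total_tokens.saturating_sub(batch_max_tokens)
}

fn main() {
    // Room left next to a half-full running batch.
    assert_eq!(remaining_token_budget(16000, 6000), 10000);
    // A batch at (or briefly over) capacity leaves a budget of 0, so the queue
    // simply returns no new batch rather than underflowing.
    assert_eq!(remaining_token_budget(16000, 16000), 0);
    assert_eq!(remaining_token_budget(16000, 17000), 0);
    println!("ok");
}
```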
@@ -28,15 +28,15 @@ struct Args {
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
-    #[clap(default_value = "1000", long, env)]
+    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,
-    #[clap(default_value = "1512", long, env)]
+    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
-    #[clap(default_value = "32000", long, env)]
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(default_value = "16000", long, env)]
    max_batch_total_tokens: u32,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
@@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> {
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
-        max_batch_size,
        waiting_served_ratio,
-        mut max_batch_total_tokens,
+        max_batch_prefill_tokens,
+        max_batch_total_tokens,
        max_waiting_tokens,
        port,
        master_shard_uds_path,
@@ -97,8 +97,18 @@ fn main() -> Result<(), std::io::Error> {
        ngrok_password,
    } = args;

+    // Validate args
+    if max_input_length as u32 > max_batch_prefill_tokens {
+        panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"));
+    }
+    if max_batch_prefill_tokens > max_batch_total_tokens {
+        panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"));
+    }
+    if max_total_tokens as u32 > max_batch_total_tokens {
+        panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"));
+    }
    if validation_workers == 0 {
-        panic!("validation_workers must be > 0");
+        panic!("`validation_workers` must be > 0");
    }

    // CORS allowed origins
@@ -141,12 +151,6 @@ fn main() -> Result<(), std::io::Error> {
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

-            if let Some(max_batch_size) = max_batch_size {
-                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
-                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
-                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
-            }
-
            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
@@ -161,9 +165,15 @@ fn main() -> Result<(), std::io::Error> {
                    sha: None,
                    pipeline_tag: None,
                },
-                false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
+                false => get_model_info(&tokenizer_name, &revision, authorization_token)
+                    .await
+                    .unwrap_or_else(|| {
                    tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-                    HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
+                    HubModelInfo {
+                        model_id: tokenizer_name.to_string(),
+                        sha: None,
+                        pipeline_tag: None,
+                    }
                }),
            };

@@ -190,6 +200,17 @@ fn main() -> Result<(), std::io::Error> {
                .info()
                .await
                .expect("Unable to get shard info");

+            // Warmup model
+            tracing::info!("Warming up model");
+            sharded_client
+                .warmup(
+                    max_input_length as u32,
+                    max_batch_prefill_tokens,
+                    max_batch_total_tokens,
+                )
+                .await
+                .expect("Unable to warmup model");
            tracing::info!("Connected");

            // Binds on localhost
@@ -206,6 +227,7 @@ fn main() -> Result<(), std::io::Error> {
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
+                max_batch_prefill_tokens,
                max_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
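Taken together, the new checks enforce a simple ordering of the limits: `max_input_length <= max_batch_prefill_tokens <= max_batch_total_tokens`, and `max_total_tokens <= max_batch_total_tokens`. A compact sketch of the same invariant chain as a reusable check; the function name, error type, and sample values here are illustrative, not part of the router:

```rust
fn validate_token_limits(
    max_input_length: u32,
    max_total_tokens: u32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
) -> Result<(), String> {
    // One request's input must fit in a single prefill pass.
    if max_input_length > max_batch_prefill_tokens {
        return Err(format!(
            "`max_batch_prefill_tokens` ({max_batch_prefill_tokens}) must be >= `max_input_length` ({max_input_length})"
        ));
    }
    // A prefill pass must fit in the overall batch budget.
    if max_batch_prefill_tokens > max_batch_total_tokens {
        return Err(format!(
            "`max_batch_prefill_tokens` ({max_batch_prefill_tokens}) must be <= `max_batch_total_tokens` ({max_batch_total_tokens})"
        ));
    }
    // A single request (input + output) must also fit in the batch budget.
    if max_total_tokens > max_batch_total_tokens {
        return Err(format!(
            "`max_total_tokens` ({max_total_tokens}) must be <= `max_batch_total_tokens` ({max_batch_total_tokens})"
        ));
    }
    Ok(())
}

fn main() {
    // The new defaults (1024, 2048, 4096, 16000) satisfy the chain.
    assert!(validate_token_limits(1024, 2048, 4096, 16000).is_ok());
    // Shrinking the batch budget below the prefill budget is rejected.
    assert!(validate_token_limits(1024, 2048, 4096, 2048).is_err());
}
```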
@@ -58,6 +58,7 @@ impl Queue {
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
+        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
@@ -67,6 +68,7 @@ impl Queue {
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
+                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
@@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
            }
            QueueCommand::NextBatch {
                min_size,
+                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => span.in_scope(|| {
-                let next_batch = state.next_batch(min_size, token_budget);
+                let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
            }),
@@ -140,7 +143,12 @@ impl State {
    }

    // Get the next batch
-    fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
+    fn next_batch(
+        &mut self,
+        min_size: Option<usize>,
+        prefill_token_budget: u32,
+        token_budget: u32,
+    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            return None;
        }
@@ -184,7 +192,9 @@ impl State {

            decode_tokens += entry.request.stopping_parameters.max_new_tokens;

-            if (prefill_tokens + decode_tokens) > token_budget {
+            if prefill_tokens > prefill_token_budget
+                || (prefill_tokens + decode_tokens) > token_budget
+            {
                // Entry is over budget
                // Add it back to the front
                self.entries.push_front((id, entry));
@@ -259,6 +269,7 @@ enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
+        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
@@ -328,8 +339,8 @@ mod tests {
    fn test_next_batch_empty() {
        let mut state = State::new(false);

-        assert!(state.next_batch(None, 1).is_none());
-        assert!(state.next_batch(Some(1), 1).is_none());
+        assert!(state.next_batch(None, 1, 1).is_none());
+        assert!(state.next_batch(Some(1), 1, 1).is_none());
    }

    #[test]
@@ -340,7 +351,7 @@ mod tests {
        state.append(entry1);
        state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
@@ -356,7 +367,7 @@ mod tests {
        let (entry3, _guard3) = default_entry();
        state.append(entry3);

-        assert!(state.next_batch(Some(2), 2).is_none());
+        assert!(state.next_batch(Some(2), 2, 2).is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
@@ -372,7 +383,7 @@ mod tests {
        state.append(entry1);
        state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, 1).unwrap();
+        let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
@@ -385,7 +396,7 @@ mod tests {
        let (entry3, _guard3) = default_entry();
        state.append(entry3);

-        let (entries, batch, _) = state.next_batch(None, 3).unwrap();
+        let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
@@ -408,8 +419,8 @@ mod tests {
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false);

-        assert!(queue.next_batch(None, 1).await.is_none());
-        assert!(queue.next_batch(Some(1), 1).await.is_none());
+        assert!(queue.next_batch(None, 1, 1).await.is_none());
+        assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
    }

    #[tokio::test]
@@ -420,7 +431,7 @@ mod tests {
        queue.append(entry1);
        queue.append(entry2);

-        let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
@@ -433,11 +444,11 @@ mod tests {
        queue.append(entry3);

        // Not enough requests pending
-        assert!(queue.next_batch(Some(2), 2).await.is_none());
+        assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
        // Not enough token budget
-        assert!(queue.next_batch(Some(1), 0).await.is_none());
+        assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
        // Ok
-        let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap();
+        let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
@@ -453,7 +464,7 @@ mod tests {
        queue.append(entry1);
        queue.append(entry2);

-        let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
@@ -462,7 +473,7 @@ mod tests {
        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

-        let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
+        let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
@@ -476,6 +487,6 @@ mod tests {
        let (entry, _) = default_entry();
        queue.append(entry);

-        assert!(queue.next_batch(None, 1).await.is_none());
+        assert!(queue.next_batch(None, 1, 1).await.is_none());
    }
}
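With the extra `prefill_token_budget` argument, `State::next_batch` stops admitting queued entries as soon as either the accumulated input tokens exceed the prefill budget or input plus requested output tokens exceed the total budget. A self-contained sketch of that admission rule over plain `(input_length, max_new_tokens)` pairs; the entry representation is simplified and not the router's `Entry` type:

```rust
/// `(input_length, max_new_tokens)` pairs standing in for queued requests.
fn admit(entries: &[(u32, u32)], prefill_budget: u32, token_budget: u32) -> usize {
    let mut prefill_tokens = 0;
    let mut decode_tokens = 0;
    let mut admitted = 0;
    for &(input_length, max_new_tokens) in entries {
        prefill_tokens += input_length;
        decode_tokens += max_new_tokens;
        // Same double check as State::next_batch: stop at the first entry that
        // would blow either the prefill budget or the total budget.
        if prefill_tokens > prefill_budget || prefill_tokens + decode_tokens > token_budget {
            return admitted;
        }
        admitted += 1;
    }
    admitted
}

fn main() {
    let queued = [(1000, 100), (1000, 100), (3000, 100)];
    // The third entry alone would push prefill to 5000 > 4096, so only two are admitted.
    assert_eq!(admit(&queued, 4096, 16000), 2);
    // A tight total budget can also be the limiting factor.
    assert_eq!(admit(&queued, 4096, 2100), 1);
    println!("ok");
}
```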
@@ -514,6 +514,7 @@ pub async fn run(
    max_input_length: usize,
    max_total_tokens: usize,
    waiting_served_ratio: f32,
+    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    client: ShardedClient,
@@ -531,6 +532,7 @@ pub async fn run(
    #[derive(OpenApi)]
    #[openapi(
    paths(
+        health,
        get_model_info,
        compat_generate,
        generate,
@@ -582,6 +584,7 @@ pub async fn run(
        client,
        validation,
        waiting_served_ratio,
+        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_concurrent_requests,
@@ -1,3 +1,3 @@
[toolchain]
-channel = "1.69.0"
+channel = "1.70.0"
components = ["rustfmt", "clippy"]
@@ -1,4 +1,5 @@
include Makefile-flash-att
+include Makefile-vllm

unit-tests:
	pytest -s -vv -m "not private" tests
@@ -1,9 +1,9 @@
-flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f
+flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec

flash-attention:
	# Clone flash attention
	pip install packaging
-	git clone https://github.com/OlivierDehaene/flash-attention.git
+	git clone https://github.com/HazyResearch/flash-attention.git

build-flash-attention: flash-attention
	cd flash-attention && git fetch && git checkout $(flash_att_commit)
13	server/Makefile-vllm	Normal file
@@ -0,0 +1,13 @@
+vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
+
+vllm:
+	# Clone vllm
+	git clone https://github.com/OlivierDehaene/vllm.git
+
+build-vllm: vllm
+	cd vllm && git fetch && git checkout $(vllm_commit)
+	cd vllm && python setup.py build
+
+install-vllm: build-vllm
+	pip uninstall vllm -y || true
+	cd vllm && python setup.py install
1904	server/poetry.lock	generated
File diff suppressed because it is too large
@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
-version = "0.8.2"
+version = "0.9.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

@@ -26,7 +26,8 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "0.13.3"
huggingface-hub = "^0.14.1"
-transformers = "^4.29.2"
+transformers = "4.29.2"
+einops = "^0.6.1"

[tool.poetry.extras]
accelerate = ["accelerate"]
@@ -1,21 +1,23 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
+bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
-colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
+einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
-numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0"
+numpy==1.25.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -26,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
-protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0"
+protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0"
pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
-setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
+setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
-transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0"
+transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
-typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -22,7 +22,9 @@ class Cache:
            del batch

    def clear(self):
-        self.cache.clear()
+        keys = list(self.cache.keys())
+        for k in keys:
+            self.delete(k)

    def __len__(self):
        return len(self.cache.keys())
@@ -17,12 +17,18 @@ class Quantization(str, Enum):
    gptq_cuda = "gptq-cuda"


+class Dtype(str, Enum):
+    float16 = "float16"
+    bloat16 = "bfloat16"
+
+
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
+    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
@@ -65,7 +71,14 @@ def serve(

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
+    dtype = None if dtype is None else dtype.value
+    if dtype is not None and quantize is not None:
+        raise RuntimeError(
+            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
+        )
+    server.serve(
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+    )


@app.command()
@@ -10,6 +10,7 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
+from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
@@ -100,11 +101,25 @@ def get_model(
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
+    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
+    if dtype is None:
+        dtype = torch.float16
+    elif dtype == "float16":
+        dtype = torch.float16
+    elif dtype == "bfloat16":
+        dtype = torch.bfloat16
+    else:
+        raise RuntimeError(f"Unknown dtype {dtype}")
+
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
@@ -113,6 +128,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:
@@ -124,6 +140,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -138,6 +155,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:
@@ -149,11 +167,20 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "bloom":
        return BLOOMSharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+    elif model_type == "mpt":
+        return MPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

@@ -163,6 +190,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:
@@ -170,6 +198,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    else:
@@ -177,6 +206,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -186,6 +216,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:
@@ -195,6 +226,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -210,6 +242,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    raise NotImplementedError(
@@ -221,6 +254,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    else:
@@ -228,12 +262,17 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    elif model_type == "opt":
        return OPTSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
        )

    elif model_type == "t5":
@@ -241,6 +280,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -253,11 +293,19 @@ def get_model(

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)
@@ -267,6 +315,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if "AutoModelForSeq2SeqLM" in auto_map.keys():
@@ -274,6 +323,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
+            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32
@@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

        return cls(
            batch_id=pb.id,
@@ -454,11 +454,12 @@ class CausalLM(Model):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
@@ -23,12 +23,16 @@ import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
-from typing import Optional
+from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm

+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
+
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
@@ -106,7 +110,7 @@ class FlashLlamaAttention(torch.nn.Module):
            prefix=f"{prefix}.rotary_emb", weights=weights
        )

-        self.softmax_scale = self.head_size ** (-0.5)
+        self.softmax_scale = self.head_size**-0.5

        self.num_heads = self.num_heads // weights.process_group.size()
        self.query_key_value = TensorParallelColumnLinear.load_multi(
@@ -122,20 +126,21 @@ class FlashLlamaAttention(torch.nn.Module):
            weights=weights,
            bias=False,
        )
+        self.kv_head_mapping = torch.arange(
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )

    def forward(
        self,
        hidden_states,
        cos,
        sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
        max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
    ):
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -144,23 +149,23 @@ class FlashLlamaAttention(torch.nn.Module):
        self.rotary_emb(qkv[:, 0], cos, sin)
        self.rotary_emb(qkv[:, 1], cos, sin)

-        # Prefill
-        if prefill:
-            # Copy to layer past
-            layer_past[...] = qkv[:, 1:]
+        vllm_cache_ops.reshape_and_cache(
+            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+        )

-            # output
-            attn_output = torch.empty_like(qkv[:, 0])
+        # output tensor
+        attn_output = torch.empty_like(qkv[:, 0])

+        # Prefill
+        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn_cuda.fwd(
                qkv[:, 0],
                qkv[:, 1],
                qkv[:, 2],
                attn_output,
-                start_seq,
-                end_seq,
-                start_seq,
-                end_seq,
+                cu_seqlen_prefill,
+                cu_seqlen_prefill,
                max_s,
                max_s,
                0.0,
@@ -173,31 +178,19 @@ class FlashLlamaAttention(torch.nn.Module):
            )
        # Decode
        else:
-            query = qkv[:, 0]
-            # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = qkv[:, 1:]
-
-            # output
-            attn_output = torch.empty_like(query)
-            # flash attention
-            flash_attn_cuda.fwd(
-                query,
-                layer_past[:, 0],
-                layer_past[:, 1],
+            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
+            block_size = kv_cache[1].shape[3]
+            vllm_attention_ops.single_query_cached_kv_attention(
                attn_output,
-                start_seq_q,
-                end_seq_q,
-                start_seq,
-                end_seq,
-                1,
-                max_s,
-                0.0,
+                qkv[:, 0],
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
                self.softmax_scale,
-                False,
-                False,
-                False,
-                0,
-                None,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -265,14 +258,12 @@ class FlashLlamaLayer(nn.Module):
        residual,
        cos,
        sin,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
        max_s,
-        layer_past,
-        past_present_indices,
-        prefill,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -281,14 +272,12 @@ class FlashLlamaLayer(nn.Module):
            normed_hidden_states,
            cos,
            sin,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
            max_s,
-            layer_past,
-            past_present_indices,
-            prefill,
        )

        # faster post attention rms norm
@@ -333,40 +322,17 @@ class FlashLlamaModel(torch.nn.Module):

    def forward(
        self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values=None,
-        pre_allocate_past_size: Optional[int] = None,
-    ):
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

-        # Prefill
-        if past_key_values is None:
-            assert pre_allocate_past_size is not None
-
-            prefill = True
-
-            # Create past tensor
-            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(input_ids),
-                    len(self.layers),
-                    2,
-                    self.num_heads,
-                    self.head_size,
-                )
-            )
-        # Decode
-        else:
-            prefill = False
-
        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -380,34 +346,17 @@ class FlashLlamaModel(torch.nn.Module):
                residual,
                cos,
                sin,
-                start_seq,
-                end_seq,
-                start_seq_q,
-                end_seq_q,
+                cu_seqlen_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
                max_s,
-                past_key_values[:, i],
-                past_present_indices,
-                prefill,
            )

-        if prefill:
-            present = past_key_values
-            # Create padded past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    pre_allocate_past_size,
-                    len(self.layers),
-                    2,
-                    self.num_heads,
-                    self.head_size,
-                )
-            )
-            # We slice only once instead of at every layer
-            past_key_values[past_present_indices] = present
-
        hidden_states, _ = self.norm(hidden_states, residual)

-        return hidden_states, past_key_values
+        return hidden_states


class FlashLlamaForCausalLM(torch.nn.Module):
@@ -423,31 +372,27 @@ class FlashLlamaForCausalLM(torch.nn.Module):

    def forward(
        self,
-        input_ids,
-        position_ids,
-        start_seq,
-        end_seq,
-        start_seq_q,
-        end_seq_q,
-        max_s,
-        past_present_indices,
-        past_key_values: Optional[torch.Tensor] = None,
-        pre_allocate_past_size: Optional[int] = None,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
-    ):
-        hidden_states, present = self.model(
+    ) -> torch.Tensor:
+        hidden_states = self.model(
            input_ids,
            position_ids,
-            start_seq,
-            end_seq,
-            start_seq_q,
-            end_seq_q,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
            max_s,
-            past_present_indices,
-            past_key_values,
-            pre_allocate_past_size,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
-        return logits, present
+        return logits
|
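Editor's note (not part of the commit): the decode branch above replaces the flash-attention call over a dense layer_past with vllm_attention_ops.single_query_cached_kv_attention over a paged KV cache. The sketch below is a slow pure-PyTorch reference for what that kernel is assumed to compute; the helper name paged_decode_attention_ref and the tensor sizes are illustrative only, and the cache layout is the one built by the CacheManager added later in this diff.

import math
import torch

def paged_decode_attention_ref(query, key_cache, value_cache, kv_head_mapping,
                               softmax_scale, block_tables, input_lengths, block_size):
    # query: [num_seqs, num_heads, head_size]; one new token per sequence.
    # key_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x]
    # value_cache: [num_blocks, num_kv_heads, head_size, block_size]
    num_seqs, num_heads, head_size = query.shape
    out = torch.empty_like(query)
    for s in range(num_seqs):
        length = int(input_lengths[s])
        blocks = block_tables[s, : math.ceil(length / block_size)]
        # Re-assemble contiguous K/V for this sequence from its cache blocks
        k = key_cache[blocks].permute(0, 3, 1, 2, 4).reshape(-1, key_cache.shape[1], head_size)[:length]
        v = value_cache[blocks].permute(0, 3, 1, 2).reshape(-1, value_cache.shape[1], head_size)[:length]
        for h in range(num_heads):
            kv_h = int(kv_head_mapping[h])  # maps query heads onto shared KV heads (MQA/GQA)
            scores = (query[s, h] @ k[:, kv_h].T) * softmax_scale
            out[s, h] = torch.softmax(scores, dim=-1) @ v[:, kv_h]
    return out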
@@ -25,11 +25,15 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
-from typing import Optional
+from typing import Optional, List, Tuple
 
 # Flash attention imports
 import flash_attn_cuda
 
+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
 
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
 TensorParallelColumnLinear,
@@ -110,20 +114,21 @@ class FlashNeoxAttention(torch.nn.Module):
 self.dense = load_row(
 config, prefix=f"{prefix}.dense", weights=weights, bias=True
 )
+self.kv_head_mapping = torch.arange(
+0, self.num_heads, dtype=torch.int32, device=weights.device
+)
 
 def forward(
 self,
 hidden_states,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 qkv = self.query_key_value(hidden_states)
 qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -132,23 +137,23 @@ class FlashNeoxAttention(torch.nn.Module):
 self.rotary_emb(qkv[:, 0], cos, sin)
 self.rotary_emb(qkv[:, 1], cos, sin)
 
-# Prefill
-if prefill:
-# Copy to layer past
-layer_past[...] = qkv[:, 1:]
+vllm_cache_ops.reshape_and_cache(
+qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+)
 
-# output
+# output tensor
 attn_output = torch.empty_like(qkv[:, 0])
 
+# Prefill
+if cu_seqlen_prefill is not None:
 # flash attention
 flash_attn_cuda.fwd(
 qkv[:, 0],
 qkv[:, 1],
 qkv[:, 2],
 attn_output,
-start_seq,
-end_seq,
-start_seq,
-end_seq,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -161,31 +166,19 @@ class FlashNeoxAttention(torch.nn.Module):
 )
 # Decode
 else:
-query = qkv[:, 0]
-# Add present to the layer_past tensor at the correct indices
-layer_past[past_present_indices] = qkv[:, 1:]
+# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
+block_size = kv_cache[1].shape[3]
+vllm_attention_ops.single_query_cached_kv_attention(
 
-# output
-attn_output = torch.empty_like(query)
-# flash attention
-flash_attn_cuda.fwd(
-query,
-layer_past[:, 0],
-layer_past[:, 1],
 attn_output,
-start_seq_q,
-end_seq_q,
-start_seq,
-end_seq,
-1,
-max_s,
-0.0,
+qkv[:, 0],
+kv_cache[0],
+kv_cache[1],
+self.kv_head_mapping,
 self.softmax_scale,
-False,
-False,
-False,
-0,
-None,
+block_tables,
+input_lengths,
+block_size,
+max_s,
 )
 
 return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -250,14 +243,12 @@ class FlashNeoXLayer(nn.Module):
 residual,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 if self.use_parallel_residual:
 ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -266,14 +257,12 @@ class FlashNeoXLayer(nn.Module):
 ln1_hidden_states,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 )
 
 ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -292,14 +281,12 @@ class FlashNeoXLayer(nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 )
 
 hidden_states, residual = self.post_attention_layernorm(
@@ -346,40 +333,17 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
 def forward(
 self,
-input_ids,
-position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
-max_s,
-past_present_indices,
-past_key_values=None,
-pre_allocate_past_size: Optional[int] = None,
-):
+input_ids: torch.Tensor,
+position_ids: torch.Tensor,
+cu_seqlen_prefill: Optional[torch.Tensor],
+kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+block_tables: torch.Tensor,
+slots: torch.Tensor,
+input_lengths: torch.Tensor,
+max_s: int,
+) -> torch.Tensor:
 hidden_states = self.embed_in(input_ids)
 
-# Prefill
-if past_key_values is None:
-assert pre_allocate_past_size is not None
 
-prefill = True
 
-# Create past tensor
-# We create a tensor of the same size as input_ids as we don't want to slice at every layer
-past_key_values = hidden_states.new_empty(
-(
-len(input_ids),
-len(self.layers),
-2,
-self.num_heads,
-self.head_size,
-)
-)
-# Decode
-else:
-prefill = False
 
 # Get rotary cos and sin for this forward
 # Avoid to index in each layer
 cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
@@ -393,34 +357,17 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 residual,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache[i],
+block_tables,
+slots,
+input_lengths,
 max_s,
-past_key_values[:, i],
-past_present_indices,
-prefill,
 )
 
-if prefill:
-present = past_key_values
-# Create padded past tensor
-past_key_values = hidden_states.new_empty(
-(
-pre_allocate_past_size,
-len(self.layers),
-2,
-self.num_heads,
-self.head_size,
-)
-)
-# We slice only once instead of at every layer
-past_key_values[past_present_indices] = present
 
 hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
-return hidden_states, past_key_values
+return hidden_states
 
 
 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
@@ -434,31 +381,27 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 
 def forward(
 self,
-input_ids,
-position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
-max_s,
-past_present_indices,
-past_key_values: Optional[torch.Tensor] = None,
-pre_allocate_past_size: Optional[int] = None,
+input_ids: torch.Tensor,
+position_ids: torch.Tensor,
+cu_seqlen_prefill: Optional[torch.Tensor],
+kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+block_tables: torch.Tensor,
+slots: torch.Tensor,
+input_lengths: torch.Tensor,
+max_s: int,
 lm_head_indices: Optional[torch.Tensor] = None,
-):
-hidden_states, present = self.gpt_neox(
+) -> torch.Tensor:
+hidden_states = self.gpt_neox(
 input_ids,
 position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-past_present_indices,
-past_key_values,
-pre_allocate_past_size,
 )
 if lm_head_indices is not None:
 hidden_states = hidden_states[lm_head_indices]
 logits = self.embed_out(hidden_states)
-return logits, present
+return logits
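Editor's note (illustration, not part of the commit): throughout these files the four start_seq/end_seq/start_seq_q/end_seq_q tensors are replaced by a single cu_seqlen_prefill tensor of cumulative sequence lengths, used only for prefill. A small example with made-up prompt lengths:

import torch

input_lengths = [3, 5, 2]            # made-up prompt lengths for one prefill batch
cu_seqlen_prefill = [0]
for length in input_lengths:
    cu_seqlen_prefill.append(cu_seqlen_prefill[-1] + length)
cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, dtype=torch.int32)
# tensor([ 0,  3,  8, 10]); sequence i occupies rows cu_seqlen_prefill[i]:cu_seqlen_prefill[i + 1]
# of the flattened [total_tokens, ...] QKV tensors passed to flash_attn_cuda.fwd.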
@@ -4,11 +4,15 @@ import torch.distributed
 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional
+from typing import Optional, List, Tuple
 
 # Flash attention imports
 import flash_attn_cuda
 
+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
 
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
 TensorParallelColumnLinear,
@@ -126,19 +130,26 @@ class FlashRWAttention(torch.nn.Module):
 config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
 )
 
+if self.num_heads_kv == 1:
+self.kv_head_mapping = torch.zeros(
+self.num_heads, dtype=torch.int32, device=weights.device
+)
+else:
+self.kv_head_mapping = torch.arange(
+0, self.num_heads, dtype=torch.int32, device=weights.device
+)
 
 def forward(
 self,
 hidden_states,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 qkv = self.query_key_value(hidden_states)
 
@@ -156,25 +167,27 @@ class FlashRWAttention(torch.nn.Module):
 self.rotary_emb(query, cos, sin)
 self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-# Prefill
-if prefill:
-# Copy to layer past
-layer_past[...] = kv
-# Expand to query shape
-kv = kv.expand(-1, 2, self.num_heads, self.head_size)
+vllm_cache_ops.reshape_and_cache(
+kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+)
 
 # output
 attn_output = torch.empty_like(query)
 
+# Prefill
+if cu_seqlen_prefill is not None:
+if self.num_heads_kv == 1:
+# Expand to query shape
+kv = kv.expand(-1, 2, self.num_heads, self.head_size)
 
 # flash attention
 flash_attn_cuda.fwd(
 query,
 torch.select(kv, dim=1, index=0),
 torch.select(kv, dim=1, index=1),
 attn_output,
-start_seq,
-end_seq,
-start_seq,
-end_seq,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -187,32 +200,19 @@ class FlashRWAttention(torch.nn.Module):
 )
 # Decode
 else:
-# Add present to the layer_past tensor at the correct indices
-layer_past[past_present_indices] = kv
-# Expand to query shape
-kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
+# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
+block_size = kv_cache[1].shape[3]
+vllm_attention_ops.single_query_cached_kv_attention(
 
-# output
-attn_output = torch.empty_like(query)
-# flash attention
-flash_attn_cuda.fwd(
-query,
-torch.select(kv, dim=1, index=0),
-torch.select(kv, dim=1, index=1),
 attn_output,
-start_seq_q,
-end_seq_q,
-start_seq,
-end_seq,
-1,
-max_s,
-0.0,
+query,
+kv_cache[0],
+kv_cache[1],
+self.kv_head_mapping,
 self.softmax_scale,
-False,
-False,
-False,
-0,
-None,
+block_tables,
+input_lengths,
+block_size,
+max_s,
 )
 
 return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -264,19 +264,21 @@ class FlashRWLargeAttention(torch.nn.Module):
 config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
 )
 
+self.kv_head_mapping = torch.arange(
+0, self.num_groups, dtype=torch.int32, device=weights.device
+).repeat_interleave(self.num_heads)
 
 def forward(
 self,
 hidden_states,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 qkv = self.query_key_value(hidden_states)
 qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -293,10 +295,19 @@ class FlashRWLargeAttention(torch.nn.Module):
 self.rotary_emb(query, cos, sin)
 self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
+vllm_cache_ops.reshape_and_cache(
+kv[:, :, 0].contiguous(),
+kv[:, :, 1].contiguous(),
+kv_cache[0],
+kv_cache[1],
+slots,
+)
 
+# output
+attn_output = torch.empty_like(query)
 
 # Prefill
-if prefill:
-# Copy to layer past
-layer_past[...] = kv
+if cu_seqlen_prefill is not None:
 # Expand to query shape
 kv = (
 kv.unsqueeze(2)
@@ -304,18 +315,14 @@ class FlashRWLargeAttention(torch.nn.Module):
 .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
 )
 
-# output
-attn_output = torch.empty_like(query)
 # flash attention
 flash_attn_cuda.fwd(
 query,
 torch.select(kv, dim=2, index=0),
 torch.select(kv, dim=2, index=1),
 attn_output,
-start_seq,
-end_seq,
-start_seq,
-end_seq,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -328,36 +335,19 @@ class FlashRWLargeAttention(torch.nn.Module):
 )
 # Decode
 else:
-# Add present to the layer_past tensor at the correct indices
-layer_past[past_present_indices] = kv
-# Expand to query shape
-kv = (
-layer_past.unsqueeze(2)
-.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-)
+# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
+block_size = kv_cache[1].shape[3]
+vllm_attention_ops.single_query_cached_kv_attention(
 
-# output
-attn_output = torch.empty_like(query)
-# flash attention
-flash_attn_cuda.fwd(
-query,
-torch.select(kv, dim=2, index=0),
-torch.select(kv, dim=2, index=1),
 attn_output,
-start_seq_q,
-end_seq_q,
-start_seq,
-end_seq,
-1,
-max_s,
-0.0,
+query,
+kv_cache[0],
+kv_cache[1],
+self.kv_head_mapping,
 self.softmax_scale,
-False,
-False,
-False,
-0,
-None,
+block_tables,
+input_lengths,
+block_size,
+max_s,
 )
 
 return self.dense(
@@ -432,14 +422,12 @@ class FlashRWLayer(nn.Module):
 residual,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 if self.parallel_attn:
 ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -448,14 +436,12 @@ class FlashRWLayer(nn.Module):
 ln_hidden_states,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 )
 
 mlp_output = self.mlp(ln_hidden_states)
@@ -472,14 +458,12 @@ class FlashRWLayer(nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 )
 
 hidden_states, residual = self.post_attention_layernorm(
@@ -523,14 +507,12 @@ class FlashRWLargeLayer(nn.Module):
 residual,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 ln_attn, residual = self.ln_attn(hidden_states, residual)
 ln_mlp, _ = self.ln_mlp(residual)
@@ -540,14 +522,12 @@ class FlashRWLargeLayer(nn.Module):
 ln_attn,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 )
 
 # MLP.
@@ -580,11 +560,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
 for layer_id in range(config.num_hidden_layers)
 ]
 )
-self.cache_size = (
-2,
-self.h[0].self_attention.num_heads_kv,
-self.h[0].self_attention.head_size,
-)
+self.cache_size = self.h[0].self_attention.num_heads_kv
 elif config.model_type == "RefinedWeb":
 self.h = nn.ModuleList(
 [
@@ -592,11 +568,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
 for layer_id in range(config.num_hidden_layers)
 ]
 )
-self.cache_size = (
-self.h[0].self_attention.num_groups,
-2,
-self.h[0].self_attention.head_size,
-)
+self.cache_size = self.h[0].self_attention.num_groups
 else:
 raise NotImplementedError(
 f"model_type {config.model_type} is not supported."
@@ -612,38 +584,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
 
 def forward(
 self,
-input_ids,
-position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
-max_s,
-past_present_indices,
-past_key_values=None,
-pre_allocate_past_size: Optional[int] = None,
-):
+input_ids: torch.Tensor,
+position_ids: torch.Tensor,
+cu_seqlen_prefill: Optional[torch.Tensor],
+kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+block_tables: torch.Tensor,
+slots: torch.Tensor,
+input_lengths: torch.Tensor,
+max_s: int,
+) -> torch.Tensor:
 hidden_states = self.word_embeddings(input_ids)
 
-# Prefill
-if past_key_values is None:
-assert pre_allocate_past_size is not None
 
-prefill = True
 
-# Create past tensor
-# We create a tensor of the same size as input_ids as we don't want to slice at every layer
-past_key_values = hidden_states.new_empty(
-(
-len(input_ids),
-len(self.h),
-*self.cache_size,
-)
-)
-# Decode
-else:
-prefill = False
 
 # Get rotary cos and sin for this forward
 # Avoid to index in each layer
 cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
@@ -657,32 +608,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
 residual,
 cos,
 sin,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache[i],
+block_tables,
+slots,
+input_lengths,
 max_s,
-torch.select(past_key_values, dim=1, index=i),
-past_present_indices,
-prefill,
 )
 
-if prefill:
-present = past_key_values
-# Create padded past tensor
-past_key_values = hidden_states.new_empty(
-(
-pre_allocate_past_size,
-len(self.h),
-*self.cache_size,
-)
-)
-# We slice only once instead of at every layer
-past_key_values[past_present_indices] = present
 
 hidden_states, _ = self.ln_f(hidden_states, residual)
 
-return hidden_states, past_key_values
+return hidden_states
 
 
 class FlashRWForCausalLM(FlashRWPreTrainedModel):
@@ -697,31 +633,27 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 
 def forward(
 self,
-input_ids,
-position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
-max_s,
-past_present_indices,
-past_key_values: Optional[torch.Tensor] = None,
-pre_allocate_past_size: Optional[int] = None,
+input_ids: torch.Tensor,
+position_ids: torch.Tensor,
+cu_seqlen_prefill: Optional[torch.Tensor],
+kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+block_tables: torch.Tensor,
+slots: torch.Tensor,
+input_lengths: torch.Tensor,
+max_s: int,
 lm_head_indices: Optional[torch.Tensor] = None,
-):
-hidden_states, present = self.transformer(
+) -> torch.Tensor:
+hidden_states = self.transformer(
 input_ids,
 position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-past_present_indices,
-past_key_values,
-pre_allocate_past_size,
 )
 if lm_head_indices is not None:
 hidden_states = hidden_states[lm_head_indices]
 logits = self.lm_head(hidden_states)
-return logits, present
+return logits
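Editor's note (illustration, not part of the commit): the kv_head_mapping buffers introduced above tell the paged decode kernel which KV head each query head should read. With made-up head counts, the three variants in this diff look like this:

import torch

num_heads = 16

# Multi-head attention (FlashNeoxAttention): every query head has its own KV head.
mha = torch.arange(0, num_heads, dtype=torch.int32)            # [0, 1, ..., 15]

# Multi-query attention (FlashRWAttention with num_heads_kv == 1, FlashMQAttention):
# all query heads share the single KV head 0.
mqa = torch.zeros(num_heads, dtype=torch.int32)                # [0, 0, ..., 0]

# Grouped attention (FlashRWLargeAttention): num_groups KV heads, each shared by
# the num_heads query heads of its group.
num_groups = 4
gqa = torch.arange(0, num_groups, dtype=torch.int32).repeat_interleave(num_heads)
# [0]*16 + [1]*16 + [2]*16 + [3]*16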
@@ -3,10 +3,15 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional
+from typing import Optional, List, Tuple
 
 # Flash attention imports
 import flash_attn_cuda
 
+# vllm imports
+import vllm_cache_ops
+import vllm_attention_ops
 
 from text_generation_server.utils.layers import (
 TensorParallelRowLinear,
 TensorParallelColumnLinear,
@@ -96,6 +101,7 @@ def _load_multi_mqa_gptq(
 def _load_multi_mqa(
 config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
 
 if any("c_attn" in k for k in weights.routing.keys()):
 slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
 shape = slice_.get_shape()
@@ -239,18 +245,19 @@ class FlashMQAttention(torch.nn.Module):
 self.c_proj = load_row(
 config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
 )
+self.kv_head_mapping = torch.zeros(
+self.num_heads, dtype=torch.int32, device=weights.device
+)
 
 def forward(
 self,
 hidden_states,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 qkv = self.c_attn(hidden_states)
 
@@ -263,10 +270,15 @@ class FlashMQAttention(torch.nn.Module):
 query = query.view(-1, self.num_heads, self.head_size)
 key_value = key_value.view(-1, 2, 1, self.head_size)
 
+vllm_cache_ops.reshape_and_cache(
+key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
+)
 
+# output
+attn_output = torch.empty_like(query)
 
 # Prefill
-if prefill:
-# Copy to layer past
-layer_past[...] = key_value
+if cu_seqlen_prefill is not None:
 # Expand from 1 to num_heads
 key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
@@ -278,10 +290,8 @@ class FlashMQAttention(torch.nn.Module):
 torch.select(key_value, dim=1, index=0),
 torch.select(key_value, dim=1, index=1),
 attn_output,
-start_seq,
-end_seq,
-start_seq,
-end_seq,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -294,32 +304,19 @@ class FlashMQAttention(torch.nn.Module):
 )
 # Decode
 else:
-# Add present to the layer_past tensor at the correct indices
-layer_past[past_present_indices] = key_value
-# Expand from 1 to num_heads
-key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
+# kv_cache[1] => [num_blocks, 1, head_size, block_size]
+block_size = kv_cache[1].shape[3]
+vllm_attention_ops.single_query_cached_kv_attention(
 
-# output
-attn_output = torch.empty_like(query)
-# flash attention
-flash_attn_cuda.fwd(
-query,
-torch.select(key_value, dim=1, index=0),
-torch.select(key_value, dim=1, index=1),
 attn_output,
-start_seq_q,
-end_seq_q,
-start_seq,
-end_seq,
-1,
-max_s,
-0.0,
+query,
+kv_cache[0],
+kv_cache[1],
+self.kv_head_mapping,
 self.softmax_scale,
-False,
-False,
-False,
-0,
-None,
+block_tables,
+input_lengths,
+block_size,
+max_s,
 )
 
 return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -379,27 +376,23 @@ class Block(nn.Module):
 self,
 hidden_states,
 residual,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 ):
 hidden_states, residual = self.ln_1(hidden_states, residual)
 
 hidden_states = self.attn(
 hidden_states,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-layer_past,
-past_present_indices,
-prefill,
 )
 
 hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -445,64 +438,36 @@ class FlashSantacoderModel(nn.Module):
 
 def forward(
 self,
-input_ids,
-position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
-max_s,
-past_present_indices,
-past_key_values=None,
-pre_allocate_past_size: Optional[int] = None,
-):
+input_ids: torch.Tensor,
+position_ids: torch.Tensor,
+cu_seqlen_prefill: Optional[torch.Tensor],
+kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+block_tables: torch.Tensor,
+slots: torch.Tensor,
+input_lengths: torch.Tensor,
+max_s: int,
+) -> torch.Tensor:
 hidden_states = self.wte(input_ids) + self.wpe(position_ids)
 
 if self.process_group.size() > 1:
 torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
-# Prefill
-if past_key_values is None:
-assert pre_allocate_past_size is not None
 
-prefill = True
 
-# Create past tensor
-# We create a tensor of the same size as input_ids as we don't want to slice at every layer
-past_key_values = hidden_states.new_zeros(
-(len(input_ids), len(self.h), 2, 1, self.head_size)
-)
-# Decode
-else:
-prefill = False
 
 residual = None
 for i, layer in enumerate(self.h):
 hidden_states, residual = layer(
 hidden_states,
 residual,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache[i],
+block_tables,
+slots,
+input_lengths,
 max_s,
-torch.select(past_key_values, dim=1, index=i),
-past_present_indices,
-prefill,
 )
 
-if prefill:
-present = past_key_values
-# Create padded past tensor
-past_key_values = hidden_states.new_empty(
-(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
-)
-# We slice only once instead of at every layer
-past_key_values[past_present_indices] = present
 
 hidden_states, _ = self.ln_f(hidden_states, residual)
 
-return hidden_states, past_key_values
+return hidden_states
 
 
 class FlashSantacoderForCausalLM(nn.Module):
@@ -539,31 +504,27 @@ class FlashSantacoderForCausalLM(nn.Module):
 
 def forward(
 self,
-input_ids,
-position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
-max_s,
-past_present_indices,
-past_key_values: Optional[torch.Tensor] = None,
-pre_allocate_past_size: Optional[int] = None,
+input_ids: torch.Tensor,
+position_ids: torch.Tensor,
+cu_seqlen_prefill: Optional[torch.Tensor],
+kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+block_tables: torch.Tensor,
+slots: torch.Tensor,
+input_lengths: torch.Tensor,
+max_s: int,
 lm_head_indices: Optional[torch.Tensor] = None,
-):
-hidden_states, present = self.transformer(
+) -> torch.Tensor:
+hidden_states = self.transformer(
 input_ids,
 position_ids,
-start_seq,
-end_seq,
-start_seq_q,
-end_seq_q,
+cu_seqlen_prefill,
+kv_cache,
+block_tables,
+slots,
+input_lengths,
 max_s,
-past_present_indices,
-past_key_values,
-pre_allocate_past_size,
 )
 if lm_head_indices is not None:
 hidden_states = hidden_states[lm_head_indices]
 logits = self.lm_head(hidden_states)
-return logits, present
+return logits
1140	server/text_generation_server/models/custom_modeling/mpt_modeling.py	Normal file
File diff suppressed because it is too large
@@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
 try:
 self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
 except RuntimeError:
-self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
+self.shared = TensorParallelEmbedding(
+prefix="encoder.embed_tokens", weights=weights
+)
 
 encoder_config = copy.deepcopy(config)
 encoder_config.is_decoder = False
@@ -1,11 +1,14 @@
+import math
+import itertools
 import torch
 import torch.distributed
 
 import numpy as np
 
 from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
-from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
@@ -20,6 +23,92 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 
 tracer = trace.get_tracer(__name__)
 
+BLOCK_SIZE = 16
+# Will be set in warmup
+CACHE_MANAGER: Optional["CacheManager"] = None
 
 
+class CacheManager:
+def __init__(
+self,
+num_blocks: int,
+num_layers: int,
+num_heads: int,
+head_size: int,
+dtype: torch.dtype,
+device: torch.device,
+):
+self.block_size = BLOCK_SIZE
 
+element_size = torch.tensor([], dtype=dtype).element_size()
+x = self.block_size // element_size
 
+self.kv_cache = [
+(
+torch.empty(
+(num_blocks, num_heads, head_size // x, self.block_size, x),
+dtype=dtype,
+device=device,
+),
+torch.empty(
+(num_blocks, num_heads, head_size, self.block_size),
+dtype=dtype,
+device=device,
+),
+)
+for _ in range(num_layers)
+]
+self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
+self.slots = torch.arange(
+0, num_blocks * self.block_size, dtype=torch.int32
+).view(num_blocks, self.block_size)
 
+def allocate(self, batch: "FlashCausalLMBatch"):
+# Get free blocks indices by finding values in mask that are not set to 0
+free_block_indices = self.free_block_mask.nonzero()
+assert (
+len(free_block_indices) >= batch.blocks
+), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
 
+# Slice by the number of required blocks
+block_indices = free_block_indices[: batch.blocks]
+block_indices = block_indices.flatten()
 
+# Padded block tables
+block_tables_tensor = torch.zeros(
+(len(batch), batch.max_blocks), dtype=torch.int32
+)
 
+# Allocate paged attention blocks
+cumulative_blocks = 0
+slots = []
+block_tables = []
+for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
+# Get allocated blocks for this sequence
+allocated_blocks = block_indices[
+cumulative_blocks : cumulative_blocks + needed_blocks
+]
+# Get slots for the allocated blocks
+allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
 
+slots.append(allocated_slots)
+block_tables.append(allocated_blocks.tolist())
+block_tables_tensor[i, :needed_blocks] = allocated_blocks
+cumulative_blocks += needed_blocks
 
+batch.needed_blocks_slots = None
+batch.block_tables = block_tables
+batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
+batch.slots = torch.concat(slots).to(batch.input_ids.device)
 
+# Allocate the required number of blocks by setting the mask to 0
+self.free_block_mask[block_indices] = 0
 
+def free(self, block_indices: Optional[List[int]]):
+if block_indices is not None and block_indices:
+# Reset mask
+self.free_block_mask[block_indices] = 1
 
 
 @dataclass
 class FlashCausalLMBatch(Batch):
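Editor's note (not part of the commit): given the key/value cache shapes built by CacheManager above, the write path used in the attention modules (vllm_cache_ops.reshape_and_cache) is assumed to behave like the slow reference below. The helper name is illustrative; a "slot" is a flat index into the paged cache, with block = slot // block_size and offset = slot % block_size.

import torch

def reshape_and_cache_ref(key, value, key_cache, value_cache, slots):
    # key/value: [num_tokens, num_kv_heads, head_size]
    # key_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x]
    # value_cache: [num_blocks, num_kv_heads, head_size, block_size]
    num_kv_heads = key_cache.shape[1]
    x = key_cache.shape[4]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    for i, slot in enumerate(slots.tolist()):
        block, offset = divmod(slot, block_size)
        # Key is stored with its head dimension split into (head_size // x, x) chunks
        key_cache[block, :, :, offset, :] = key[i].reshape(num_kv_heads, head_size // x, x)
        value_cache[block, :, :, offset] = value[i]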
@@ -32,23 +121,29 @@ class FlashCausalLMBatch(Batch):
 input_ids: torch.Tensor
 position_ids: torch.Tensor
 
-# Indices to copy present to the correct indices is the pre-allocated past key values
-past_present_indices: torch.Tensor
+# Flash Attention values
+# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
+cu_seqlen_prefill: Optional[torch.Tensor]
 
+# Paged Attention values
 
+# Set when creating the batch
+# CPU tensor of length b indicating the start of each sequence in slots
+start_slots: torch.Tensor
+# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
+slot_indices: torch.Tensor
+# List of tuple of ints representing the number of blocks and slots needed by each sequence
+needed_blocks_slots: Optional[List[Tuple[int, int]]]
 
+# Set in prefill by the CacheManager
+# list of length b of list of length s_i // block_size
+block_tables: Optional[List[List[int]]]
+# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
+block_tables_tensor: Optional[torch.Tensor]
+# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
+slots: Optional[torch.Tensor]
 
-# tensor of length b holding starting offset of each sequence
-start_seq: torch.Tensor
-# tensor of length b holding ending offset of each sequence
-end_seq: torch.Tensor
-# tensor of length b holding starting offset of each sequence, only used in prefill
-start_seq_prefill: Optional[torch.Tensor]
-# tensor of length b holding ending offset of each sequence, only used in prefill
-end_seq_prefill: Optional[torch.Tensor]
-# tensor of length b holding starting offset of each query sequence, only used in decode
-start_seq_q: Optional[torch.Tensor]
-# tensor of length b holding ending offset of each query sequence, only used in decode
-end_seq_q: Optional[torch.Tensor]
-# past key values, only used in decode
-past_key_values: Optional[torch.Tensor]
 max_seqlen: int
 
 # Prefill metadata tensors to efficiently compute logprobs
@@ -62,6 +157,7 @@ class FlashCausalLMBatch(Batch):
 
 # Lengths of all generations present in the batch
 input_lengths: List[int]
+input_lengths_tensor: torch.Tensor
 prefix_offsets: List[Optional[int]]
 read_offsets: List[Optional[int]]
 
@@ -69,15 +165,17 @@ class FlashCausalLMBatch(Batch):
 next_token_chooser: HeterogeneousNextTokenChooser
 stopping_criterias: List[StoppingCriteria]
 
-# Maximum number of tokens this batch will grow to
-max_tokens: int
+# Number of blocks in this batch
+blocks: int
+# Maximum number of blocks
+max_blocks: int
 
 def to_pb(self) -> generate_pb2.CachedBatch:
 return generate_pb2.CachedBatch(
 id=self.batch_id,
 request_ids=[r.id for r in self.requests],
 size=len(self),
-max_tokens=self.max_tokens,
+max_tokens=self.blocks * BLOCK_SIZE,
 )
 
 @classmethod
@@ -99,12 +197,10 @@ class FlashCausalLMBatch(Batch):
 )["input_ids"]
 
 position_ids = []
-past_present_indices = []
-start_seq = []
-end_seq = []
-start_seq_prefill = []
-end_seq_prefill = []
-max_seqlen = 0
+cu_seqlen_prefill = [0]
+needed_blocks_slots = []
+start_slots = []
+slot_indices = []
 
 input_lengths = []
 prefix_offsets = []
@@ -126,7 +222,10 @@ class FlashCausalLMBatch(Batch):
 cumulative_max_length = 0
 prefill_out_cumulative_length = 0
 
+blocks = 0
+max_seqlen = 0
 max_length = 0
+max_blocks = 0
 
 # Parse batch
 for i, (r, tokenized_input) in enumerate(
@@ -138,7 +237,6 @@ class FlashCausalLMBatch(Batch):
 tokenized_input = tokenized_input[-r.truncate :]
 
 input_length = len(tokenized_input)
-max_seqlen = max(max_seqlen, input_length)
 input_lengths.append(input_length)
 
 prefix_offsets.append(input_length - 5)
@@ -151,10 +249,7 @@ class FlashCausalLMBatch(Batch):
 position_ids.append(request_position_ids)
 
 # Add cumulative lengths of all previous inputs
-start_seq_prefill.append(cumulative_length)
-end_seq_prefill.append(cumulative_length + input_length)
-start_seq.append(cumulative_max_length)
-end_seq.append(cumulative_max_length + input_length)
+cu_seqlen_prefill.append(cumulative_length + input_length)
 
 next_token_chooser_parameters.append(r.parameters)
 
@@ -164,6 +259,21 @@ class FlashCausalLMBatch(Batch):
 max_new_tokens = stopping_criteria.max_new_tokens
 stopping_criterias.append(stopping_criteria)
 
+# Paged attention
+# Remove one as the first token des not have a past
+total_tokens = input_length + max_new_tokens - 1
+needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+blocks += needed_blocks
+needed_blocks_slots.append((needed_blocks, total_tokens))
+start_slots.append(cumulative_max_length)
 
+request_slot_indices = torch.arange(
+cumulative_max_length,
+cumulative_max_length + input_length,
+dtype=torch.int64,
+)
+slot_indices.append(request_slot_indices)
 
 all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
 no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
 
@@ -184,22 +294,17 @@ class FlashCausalLMBatch(Batch):
 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
 prefill_out_cumulative_length += 1
 
-request_past_present_indices = torch.arange(
-cumulative_max_length,
-cumulative_max_length + input_length,
-dtype=torch.int64,
-)
-past_present_indices.append(request_past_present_indices)
 
 # Update
-# Remove one as the first token des not have a past
 cumulative_length += input_length
-cumulative_max_length += input_length + max_new_tokens - 1
+cumulative_max_length += total_tokens
+max_seqlen = max(max_seqlen, input_length)
+max_blocks = max(max_blocks, needed_blocks)
 max_length = max(max_length, input_length + max_new_tokens)
 
 next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
 next_token_chooser_parameters, dtype, device
 )
+start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
 # Padded all_input_ids_tensor
 all_input_ids_tensor = np.zeros(
@@ -212,41 +317,32 @@ class FlashCausalLMBatch(Batch):
 all_input_ids_tensor = torch.tensor(
 all_input_ids_tensor, dtype=torch.int64, device=device
 )
-start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
-end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
 
 if len(pb.requests) > 1:
 input_ids = np.concatenate(all_input_ids, dtype=np.int64)
 position_ids = torch.cat(position_ids)
+slot_indices = torch.cat(slot_indices)
-past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
 
-start_seq_prefill = torch.tensor(
-start_seq_prefill, device=device, dtype=torch.int32
-)
-end_seq_prefill = torch.tensor(
-end_seq_prefill, device=device, dtype=torch.int32
-)
 else:
 input_ids = all_input_ids[0]
 position_ids = position_ids[0]
+slot_indices = slot_indices[0]
 
-past_present_indices = past_present_indices[0]
-start_seq_prefill = start_seq
-end_seq_prefill = end_seq
+cu_seqlen_prefill = torch.tensor(
+cu_seqlen_prefill, device=device, dtype=torch.int32
+)
 
+position_ids = position_ids.to(device)
+slot_indices = slot_indices.to(device)
 input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-past_present_indices = torch.tensor(
-past_present_indices, device=device, dtype=torch.int64
+input_lengths_tensor = torch.tensor(
+input_lengths, dtype=torch.int32, device=device
 )
 
 if all_prefill_logprobs:
 prefill_head_indices = None
-prefill_next_token_indices = end_seq_prefill - 1
+prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
 elif no_prefill_logprobs:
-prefill_head_indices = end_seq_prefill - 1
+prefill_head_indices = cu_seqlen_prefill[1:] - 1
 prefill_next_token_indices = None
 else:
 prefill_head_indices = torch.tensor(
@@ -262,26 +358,27 @@ class FlashCausalLMBatch(Batch):
 requests_idx_mapping=requests_idx_mapping,
 input_ids=input_ids,
 position_ids=position_ids,
-past_present_indices=past_present_indices,
-start_seq=start_seq,
-end_seq=end_seq,
-start_seq_prefill=start_seq_prefill,
-end_seq_prefill=end_seq_prefill,
-start_seq_q=None,
-end_seq_q=None,
+cu_seqlen_prefill=cu_seqlen_prefill,
+start_slots=start_slots,
+slot_indices=slot_indices,
+needed_blocks_slots=needed_blocks_slots,
+block_tables=None,
+block_tables_tensor=None,
+slots=None,
 max_seqlen=max_seqlen,
 prefill_head_indices=prefill_head_indices,
 prefill_next_token_indices=prefill_next_token_indices,
 prefill_cu_outlens=prefill_cu_outlens,
-past_key_values=None,
 input_lengths=input_lengths,
+input_lengths_tensor=input_lengths_tensor,
 prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
next_token_chooser=next_token_chooser,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_tokens=cumulative_max_length,
|
blocks=blocks,
|
||||||
|
max_blocks=max_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("filter")
|
@tracer.start_as_current_span("filter")
|
||||||
@ -294,28 +391,24 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
device = self.input_ids.device
|
device = self.input_ids.device
|
||||||
|
|
||||||
# Cumulative length
|
|
||||||
cumulative_max_length = 0
|
|
||||||
|
|
||||||
# New values after filtering
|
# New values after filtering
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
# Used to index into tensors
|
# Used to index into tensors
|
||||||
indices = []
|
indices = []
|
||||||
|
|
||||||
# past indices to keep
|
# slots to keep after filtering
|
||||||
past_indices = torch.zeros(
|
slot_filtering_indices = torch.zeros(
|
||||||
self.past_key_values.shape[0], dtype=torch.bool, device=device
|
self.slots.shape[0], dtype=torch.bool, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create on CPU to only move to GPU once instead of at every copy
|
# Create on CPU to only move to GPU once instead of at every copy
|
||||||
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
|
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
|
||||||
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
|
|
||||||
start_seq_q = self.start_seq_q[: len(request_ids)]
|
|
||||||
end_seq_q = self.end_seq_q[: len(request_ids)]
|
|
||||||
max_seqlen = 0
|
max_seqlen = 0
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
|
start_slots = []
|
||||||
|
block_tables = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
@ -324,6 +417,11 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
|
||||||
|
blocks = 0
|
||||||
|
max_blocks = 0
|
||||||
|
# Cumulative length
|
||||||
|
cumulative_max_length = 0
|
||||||
|
|
||||||
for i, request_id in enumerate(request_ids):
|
for i, request_id in enumerate(request_ids):
|
||||||
idx = self.requests_idx_mapping[request_id]
|
idx = self.requests_idx_mapping[request_id]
|
||||||
indices.append(idx)
|
indices.append(idx)
|
||||||
@ -348,28 +446,51 @@ class FlashCausalLMBatch(Batch):
|
|||||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
request_block_table = self.block_tables[idx]
|
||||||
|
blocks += len(request_block_table)
|
||||||
|
block_tables.append(request_block_table)
|
||||||
|
start_slots.append(cumulative_max_length)
|
||||||
|
|
||||||
# Copy to tensor (CPU)
|
# Copy to tensor (CPU)
|
||||||
start_seq[i] = cumulative_max_length
|
slot_indices[i] = cumulative_max_length + request_input_length - 1
|
||||||
end_seq[i] = cumulative_max_length + request_input_length
|
|
||||||
|
|
||||||
# Set slice
|
# Set slice
|
||||||
past_indices[
|
slot_filtering_indices[
|
||||||
self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
|
self.start_slots[idx] : self.start_slots[idx]
|
||||||
|
+ request_input_length
|
||||||
|
+ remaining_tokens
|
||||||
|
- 1
|
||||||
] = True
|
] = True
|
||||||
|
|
||||||
cumulative_max_length += request_input_length + remaining_tokens - 1
|
cumulative_max_length += request_input_length + remaining_tokens - 1
|
||||||
|
|
||||||
|
max_blocks = max(max_blocks, len(request_block_table))
|
||||||
|
|
||||||
|
global CACHE_MANAGER
|
||||||
|
block_indices_to_free = []
|
||||||
|
# Iterate on all requests
|
||||||
|
for i, r in enumerate(self.requests):
|
||||||
|
# Filter requests that are not part of the new batch
|
||||||
|
if r.id not in requests_idx_mapping.keys():
|
||||||
|
block_indices_to_free.extend(self.block_tables[i])
|
||||||
|
# Free blocks
|
||||||
|
CACHE_MANAGER.free(block_indices_to_free)
|
||||||
|
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||||
|
self.block_tables = None
|
||||||
|
|
||||||
# Index into tensors
|
# Index into tensors
|
||||||
input_ids = self.input_ids[indices]
|
input_ids = self.input_ids[indices]
|
||||||
position_ids = self.position_ids[indices]
|
position_ids = self.position_ids[indices]
|
||||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||||
|
block_tables_tensor = self.block_tables_tensor[indices]
|
||||||
|
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||||
|
slots = self.slots[slot_filtering_indices]
|
||||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||||
past_key_values = self.past_key_values[past_indices]
|
|
||||||
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
# Move to GPU now that we have the whole tensor
|
# Move to GPU now that we have the whole tensor
|
||||||
start_seq = start_seq.to(device)
|
slot_indices = slot_indices.to(device)
|
||||||
end_seq = end_seq.to(device)
|
|
||||||
past_present_indices = end_seq - 1
|
|
||||||
|
|
||||||
return FlashCausalLMBatch(
|
return FlashCausalLMBatch(
|
||||||
batch_id=self.batch_id,
|
batch_id=self.batch_id,
|
||||||
@ -377,26 +498,27 @@ class FlashCausalLMBatch(Batch):
|
|||||||
requests_idx_mapping=requests_idx_mapping,
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_present_indices=past_present_indices,
|
cu_seqlen_prefill=None,
|
||||||
start_seq=start_seq,
|
start_slots=start_slots,
|
||||||
end_seq=end_seq,
|
slot_indices=slot_indices,
|
||||||
start_seq_prefill=None,
|
needed_blocks_slots=None,
|
||||||
end_seq_prefill=None,
|
block_tables=block_tables,
|
||||||
start_seq_q=start_seq_q,
|
block_tables_tensor=block_tables_tensor,
|
||||||
end_seq_q=end_seq_q,
|
slots=slots,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
prefill_head_indices=None,
|
prefill_head_indices=None,
|
||||||
prefill_next_token_indices=None,
|
prefill_next_token_indices=None,
|
||||||
prefill_cu_outlens=None,
|
prefill_cu_outlens=None,
|
||||||
past_key_values=past_key_values,
|
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
next_token_chooser=next_token_chooser,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_tokens=cumulative_max_length,
|
blocks=blocks,
|
||||||
|
max_blocks=max_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -406,22 +528,46 @@ class FlashCausalLMBatch(Batch):
|
|||||||
requests = []
|
requests = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
total_batch_size = sum([len(b) for b in batches])
|
blocks = 0
|
||||||
|
total_batch_size = 0
|
||||||
dtype = batches[0].past_key_values.dtype
|
total_slots = 0
|
||||||
device = batches[0].input_ids.device
|
max_blocks = 0
|
||||||
|
max_length = 0
|
||||||
|
max_seqlen = 0
|
||||||
|
for b in batches:
|
||||||
|
total_batch_size += len(b)
|
||||||
|
total_slots += len(b.slots)
|
||||||
|
blocks += b.blocks
|
||||||
|
max_blocks = max(max_blocks, b.max_blocks)
|
||||||
|
max_seqlen = max(max_seqlen, b.max_seqlen)
|
||||||
|
max_length = max(
|
||||||
|
max_length,
|
||||||
|
max(
|
||||||
|
input_length
|
||||||
|
+ stopping_criteria.max_new_tokens
|
||||||
|
- stopping_criteria.current_tokens
|
||||||
|
for input_length, stopping_criteria in zip(
|
||||||
|
b.input_lengths, b.stopping_criterias
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||||
start_seq = batches[0].start_seq.new_empty(total_batch_size)
|
slots = batches[0].slots.new_empty(total_slots)
|
||||||
end_seq = batches[0].end_seq.new_empty(total_batch_size)
|
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
||||||
start_seq_q = torch.arange(
|
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
|
||||||
0, total_batch_size, device=device, dtype=torch.int32
|
total_batch_size
|
||||||
|
)
|
||||||
|
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
|
||||||
|
(total_batch_size, max_blocks)
|
||||||
|
)
|
||||||
|
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
|
||||||
|
(total_batch_size, max_length)
|
||||||
)
|
)
|
||||||
end_seq_q = start_seq_q + 1
|
|
||||||
max_seqlen = 0
|
|
||||||
past_key_values = []
|
|
||||||
|
|
||||||
|
start_slots = []
|
||||||
|
block_tables = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
@ -433,8 +579,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
cumulative_batch_size = 0
|
cumulative_batch_size = 0
|
||||||
max_tokens = 0
|
cumulative_slots = 0
|
||||||
max_length = 0
|
|
||||||
|
|
||||||
for i, batch in enumerate(batches):
|
for i, batch in enumerate(batches):
|
||||||
requests.extend(batch.requests)
|
requests.extend(batch.requests)
|
||||||
@ -448,16 +593,27 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
start_index = cumulative_batch_size
|
start_index = cumulative_batch_size
|
||||||
end_index = cumulative_batch_size + len(batch)
|
end_index = cumulative_batch_size + len(batch)
|
||||||
|
slots_start_index = cumulative_slots
|
||||||
|
slots_end_index = cumulative_slots + len(batch.slots)
|
||||||
|
|
||||||
# Copy tensors (GPU)
|
# Copy tensors (GPU)
|
||||||
input_ids[start_index:end_index] = batch.input_ids
|
input_ids[start_index:end_index] = batch.input_ids
|
||||||
position_ids[start_index:end_index] = batch.position_ids
|
position_ids[start_index:end_index] = batch.position_ids
|
||||||
|
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
|
||||||
|
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
|
||||||
|
slots[slots_start_index:slots_end_index] = batch.slots
|
||||||
|
|
||||||
start_seq[start_index:end_index] = batch.start_seq + max_tokens
|
all_input_ids_tensor[
|
||||||
end_seq[start_index:end_index] = batch.end_seq + max_tokens
|
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||||
|
] = batch.all_input_ids_tensor[:, :max_length]
|
||||||
|
|
||||||
max_seqlen = max(max_seqlen, batch.max_seqlen)
|
block_tables_tensor[
|
||||||
|
start_index:end_index, : batch.block_tables_tensor.shape[1]
|
||||||
|
] = batch.block_tables_tensor[:, :max_blocks]
|
||||||
|
|
||||||
|
start_slots.append(batch.start_slots + cumulative_slots)
|
||||||
|
|
||||||
|
block_tables.extend(batch.block_tables)
|
||||||
all_input_ids.extend(batch.all_input_ids)
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
|
|
||||||
input_lengths.extend(batch.input_lengths)
|
input_lengths.extend(batch.input_lengths)
|
||||||
@ -466,73 +622,58 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
past_key_values.append(batch.past_key_values)
|
|
||||||
|
|
||||||
# Update
|
# Update
|
||||||
cumulative_batch_size += len(batch)
|
cumulative_batch_size += len(batch)
|
||||||
max_tokens += batch.max_tokens
|
cumulative_slots += len(batch.slots)
|
||||||
max_length = max(
|
|
||||||
max_length,
|
|
||||||
max(
|
|
||||||
input_length
|
|
||||||
+ stopping_criteria.max_new_tokens
|
|
||||||
- stopping_criteria.current_tokens
|
|
||||||
for input_length, stopping_criteria in zip(
|
|
||||||
batch.input_lengths, batch.stopping_criterias
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
past_key_values = torch.cat(past_key_values, dim=0)
|
start_slots = torch.concat(start_slots)
|
||||||
past_present_indices = end_seq - 1
|
|
||||||
|
|
||||||
all_input_ids_tensor = torch.zeros(
|
|
||||||
(total_batch_size, max_length), dtype=torch.int64, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
cumulative_batch_size = 0
|
|
||||||
for i, batch in enumerate(batches):
|
|
||||||
start_index = cumulative_batch_size
|
|
||||||
end_index = cumulative_batch_size + len(batch)
|
|
||||||
|
|
||||||
all_input_ids_tensor[
|
|
||||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
|
||||||
] = batch.all_input_ids_tensor[:, :max_length]
|
|
||||||
|
|
||||||
cumulative_batch_size += len(batch)
|
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype=dtype, device=device
|
next_token_chooser_parameters,
|
||||||
|
dtype=batches[0].next_token_chooser.dtype,
|
||||||
|
device=batches[0].next_token_chooser.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||||
|
for b in batches:
|
||||||
|
b.block_tables = None
|
||||||
|
|
||||||
return FlashCausalLMBatch(
|
return FlashCausalLMBatch(
|
||||||
batch_id=batches[0].batch_id,
|
batch_id=batches[0].batch_id,
|
||||||
requests=requests,
|
requests=requests,
|
||||||
requests_idx_mapping=requests_idx_mapping,
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_present_indices=past_present_indices,
|
cu_seqlen_prefill=None,
|
||||||
start_seq=start_seq,
|
start_slots=start_slots,
|
||||||
end_seq=end_seq,
|
slot_indices=slot_indices,
|
||||||
start_seq_prefill=None,
|
needed_blocks_slots=None,
|
||||||
end_seq_prefill=None,
|
block_tables=block_tables,
|
||||||
start_seq_q=start_seq_q,
|
block_tables_tensor=block_tables_tensor,
|
||||||
end_seq_q=end_seq_q,
|
slots=slots,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
prefill_head_indices=None,
|
prefill_head_indices=None,
|
||||||
prefill_next_token_indices=None,
|
prefill_next_token_indices=None,
|
||||||
prefill_cu_outlens=None,
|
prefill_cu_outlens=None,
|
||||||
past_key_values=past_key_values,
|
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
next_token_chooser=next_token_chooser,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_tokens=max_tokens,
|
blocks=blocks,
|
||||||
|
max_blocks=max_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.block_tables is not None and self.block_tables:
|
||||||
|
global CACHE_MANAGER
|
||||||
|
# Free blocks
|
||||||
|
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
|
||||||
@ -540,32 +681,19 @@ class FlashCausalLMBatch(Batch):
|
|||||||
class FlashCausalLM(Model):
|
class FlashCausalLM(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_cls: Type[PreTrainedModel],
|
model: torch.nn.Module,
|
||||||
model_id: str,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
revision: Optional[str] = None,
|
num_layers: int,
|
||||||
quantize: Optional[str] = None,
|
num_kv_heads: int,
|
||||||
trust_remote_code: bool = False,
|
head_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
rank: int = 0,
|
||||||
|
world_size: int = 1,
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
self.num_layers = num_layers
|
||||||
device = torch.device("cuda")
|
self.num_kv_heads = num_kv_heads
|
||||||
dtype = torch.float16
|
self.head_size = head_size
|
||||||
else:
|
|
||||||
raise NotImplementedError("FlashCausalLM is only available on GPU")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
model = model_cls.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
torch_dtype=dtype,
|
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
model=model,
|
model=model,
|
||||||
@ -573,12 +701,38 @@ class FlashCausalLM(Model):
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
        global CACHE_MANAGER

        torch.cuda.empty_cache()
        try:
            CACHE_MANAGER = CacheManager(
                # Adds some wiggle room
                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except Exception as e:
            logger.exception(
                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                f"prefill tokens. "
                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
            )
            raise e
        del batch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
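The warmup above sizes the global KV cache once, before serving: ceil(max_total_tokens / BLOCK_SIZE) + 10 blocks, then runs one prefill to surface out-of-memory errors early. The CacheManager itself is defined elsewhere in this commit; the sketch below only illustrates the allocate/free bookkeeping such a manager needs, with every name in it assumed rather than taken from the real implementation.

    import math
    from typing import List

    class ToyBlockAllocator:
        """Illustrative free-list allocator; not the CacheManager from this commit."""

        def __init__(self, num_blocks: int, block_size: int = 16):
            self.block_size = block_size
            self.free_blocks = list(range(num_blocks))

        def allocate(self, total_tokens: int) -> List[int]:
            needed = math.ceil(total_tokens / self.block_size)
            if needed > len(self.free_blocks):
                raise RuntimeError("not enough KV-cache blocks; lower the max batch tokens")
            # Hand out block ids; the caller records them in its block table.
            return [self.free_blocks.pop() for _ in range(needed)]

        def free(self, block_table: List[int]) -> None:
            self.free_blocks.extend(block_table)

    # Warmup-style sizing: enough blocks for max_total_tokens plus some wiggle room.
    allocator = ToyBlockAllocator(num_blocks=math.ceil(16000 / 16) + 10)
    table = allocator.allocate(total_tokens=119)
    allocator.free(table)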
@ -588,28 +742,25 @@ class FlashCausalLM(Model):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
start_seq: torch.Tensor,
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
end_seq: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
start_seq_q: Optional[torch.Tensor],
|
slots: torch.Tensor,
|
||||||
end_seq_q: Optional[torch.Tensor],
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
past_present_indices: torch.Tensor,
|
|
||||||
past_key_values: Optional = None,
|
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
global CACHE_MANAGER
|
||||||
|
|
||||||
# Model Forward
|
# Model Forward
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
start_seq=start_seq,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
end_seq=end_seq,
|
kv_cache=CACHE_MANAGER.kv_cache,
|
||||||
start_seq_q=start_seq_q,
|
block_tables=block_tables,
|
||||||
end_seq_q=end_seq_q,
|
slots=slots,
|
||||||
|
input_lengths=input_lengths,
|
||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
past_present_indices=past_present_indices,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
pre_allocate_past_size=pre_allocate_past_size,
|
|
||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=lm_head_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -617,31 +768,21 @@ class FlashCausalLM(Model):
|
|||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: FlashCausalLMBatch
|
self, batch: FlashCausalLMBatch
|
||||||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
||||||
prefill = batch.past_key_values is None
|
prefill = batch.cu_seqlen_prefill is not None
|
||||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||||
|
|
||||||
if prefill:
|
if batch.needed_blocks_slots:
|
||||||
# Ask to pre-allocate kv to its max size
|
# Allocate blocks to this batch
|
||||||
# == Sum over batch size (number of tokens + max_new_tokens) - batch size
|
CACHE_MANAGER.allocate(batch)
|
||||||
pre_allocate_past_size = batch.max_tokens
|
|
||||||
start_seq = batch.start_seq_prefill
|
|
||||||
end_seq = batch.end_seq_prefill
|
|
||||||
else:
|
|
||||||
pre_allocate_past_size = None
|
|
||||||
start_seq = batch.start_seq
|
|
||||||
end_seq = batch.end_seq
|
|
||||||
|
|
||||||
out, present = self.forward(
|
out = self.forward(
|
||||||
batch.input_ids,
|
batch.input_ids,
|
||||||
batch.position_ids,
|
batch.position_ids,
|
||||||
start_seq,
|
batch.cu_seqlen_prefill,
|
||||||
end_seq,
|
batch.block_tables_tensor,
|
||||||
batch.start_seq_q,
|
batch.slots[batch.slot_indices],
|
||||||
batch.end_seq_q,
|
batch.input_lengths_tensor,
|
||||||
batch.max_seqlen,
|
batch.max_seqlen,
|
||||||
batch.past_present_indices,
|
|
||||||
batch.past_key_values,
|
|
||||||
pre_allocate_past_size,
|
|
||||||
batch.prefill_head_indices,
|
batch.prefill_head_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -662,15 +803,10 @@ class FlashCausalLM(Model):
|
|||||||
# When batch == 1, we will just use the batch.input_ids values directly
|
# When batch == 1, we will just use the batch.input_ids values directly
|
||||||
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
|
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
|
||||||
|
|
||||||
# Create batch.start_seq_q and batch.end_seq_q for decode
|
|
||||||
batch.start_seq_q = torch.arange(
|
|
||||||
0, len(batch), device=self.device, dtype=torch.int32
|
|
||||||
)
|
|
||||||
batch.end_seq_q = batch.start_seq_q + 1
|
|
||||||
next_position_ids = batch.position_ids.new_empty(len(batch))
|
next_position_ids = batch.position_ids.new_empty(len(batch))
|
||||||
# We do not need start_seq_prefill and end_seq_prefill anymore
|
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
|
||||||
batch.start_seq_prefill = None
|
# We do not need cu_seqlen_prefill anymore
|
||||||
batch.end_seq_prefill = None
|
batch.cu_seqlen_prefill = None
|
||||||
else:
|
else:
|
||||||
prefill_logprobs = None
|
prefill_logprobs = None
|
||||||
next_position_ids = batch.position_ids
|
next_position_ids = batch.position_ids
|
||||||
@ -731,8 +867,8 @@ class FlashCausalLM(Model):
        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.past_present_indices = batch.end_seq
        batch.input_lengths_tensor += 1
        batch.end_seq = batch.end_seq + 1
        batch.slot_indices += 1

        if prefill and prefill_logprobs:
            # Get prefill logprobs
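After every decode step each sequence grows by exactly one token, so its length counter and its write position in the flat slots tensor both advance by one, which is all the two incremented tensors above do. A tiny tensor sketch with assumed shapes, not the real batch object:

    import torch

    # Three sequences, flat slot layout [seq0 slots | seq1 slots | seq2 slots]
    input_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)
    slot_indices = torch.tensor([4, 9, 17], dtype=torch.int64)  # last written slot per sequence

    # One decode step: every sequence produced exactly one new token.
    input_lengths += 1
    slot_indices += 1  # the next token's KV pair goes into the following slot
    print(input_lengths.tolist(), slot_indices.tolist())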
@ -755,7 +891,6 @@ class FlashCausalLM(Model):
|
|||||||
batch.read_offsets,
|
batch.read_offsets,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
batch.all_input_ids_tensor,
|
|
||||||
batch.next_token_chooser.do_sample,
|
batch.next_token_chooser.do_sample,
|
||||||
batch.next_token_chooser.seeds,
|
batch.next_token_chooser.seeds,
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
@ -770,7 +905,6 @@ class FlashCausalLM(Model):
|
|||||||
read_offset,
|
read_offset,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
all_input_ids_tensor,
|
|
||||||
do_sample,
|
do_sample,
|
||||||
seed,
|
seed,
|
||||||
next_token_id,
|
next_token_id,
|
||||||
@ -845,19 +979,20 @@ class FlashCausalLM(Model):
|
|||||||
|
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
new_input_length = input_length + 1
|
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = input_length + 1
|
||||||
batch.prefix_offsets[i] = prefix_offset
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
batch.read_offsets[i] = read_offset
|
batch.read_offsets[i] = read_offset
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
|
|
||||||
|
if stopped:
|
||||||
|
del batch
|
||||||
|
# No need to return a batch if we know that all requests stopped
|
||||||
|
return generations, None
|
||||||
|
|
||||||
batch.prefill_cu_outlens = None
|
batch.prefill_cu_outlens = None
|
||||||
batch.prefill_head_indices = None
|
batch.prefill_head_indices = None
|
||||||
batch.prefill_next_token_indices = None
|
batch.prefill_next_token_indices = None
|
||||||
batch.max_seqlen = batch.max_seqlen + 1
|
batch.max_seqlen = batch.max_seqlen + 1
|
||||||
batch.past_key_values = present
|
|
||||||
|
|
||||||
# No need to return a batch if we know that all requests stopped
|
return generations, batch
|
||||||
return generations, batch if not stopped else None
|
|
||||||
|
@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||||
|
|
||||||
@ -64,10 +65,12 @@ class FlashLlama(FlashCausalLM):
|
|||||||
model = FlashLlamaForCausalLM(config, weights)
|
model = FlashLlamaForCausalLM(config, weights)
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashLlama, self).__init__(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
num_layers=len(model.model.layers),
|
||||||
|
num_kv_heads=model.model.num_heads,
|
||||||
|
head_size=model.model.head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
|
@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||||
|
|
||||||
@ -55,10 +56,12 @@ class FlashNeoXSharded(FlashCausalLM):
|
|||||||
model = FlashGPTNeoXForCausalLM(config, weights)
|
model = FlashGPTNeoXForCausalLM(config, weights)
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashNeoXSharded, self).__init__(
|
||||||
model=model.to(device),
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
num_layers=len(model.gpt_neox.layers),
|
||||||
|
num_kv_heads=model.gpt_neox.num_heads,
|
||||||
|
head_size=model.gpt_neox.head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
|
@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashRW is only available on GPU")
|
raise NotImplementedError("FlashRW is only available on GPU")
|
||||||
|
|
||||||
@ -55,10 +56,12 @@ class FlashRWSharded(FlashCausalLM):
|
|||||||
model = FlashRWForCausalLM(config, weights)
|
model = FlashRWForCausalLM(config, weights)
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashRWSharded, self).__init__(
|
||||||
model=model.to(device),
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
num_layers=len(model.transformer.h),
|
||||||
|
num_kv_heads=model.transformer.cache_size,
|
||||||
|
head_size=model.transformer.head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
|
@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||||
|
|
||||||
@ -52,18 +53,22 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(
|
weights = Weights(
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group,
|
filenames,
|
||||||
aliases = {"transformer.wte.weight": ["lm_head.weight"]}
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
process_group=self.process_group,
|
||||||
|
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
model = FlashSantacoderForCausalLM(config, weights)
|
model = FlashSantacoderForCausalLM(config, weights)
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(FlashSantacoderSharded, self).__init__(
|
||||||
super(FlashCausalLM, self).__init__(
|
|
||||||
model=model.to(device),
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
num_layers=len(model.transformer.h),
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=model.transformer.head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
|
@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
@ -22,6 +22,9 @@ class Model(ABC):
|
|||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
):
|
):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.set_per_process_memory_fraction(1.0)
|
||||||
|
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.all_special_ids = set(tokenizer.all_special_ids)
|
self.all_special_ids = set(tokenizer.all_special_ids)
|
||||||
@ -55,6 +58,9 @@ class Model(ABC):
|
|||||||
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
|
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def warmup(self, batch: B, max_total_tokens: int):
|
||||||
|
self.generate_token(batch)
|
||||||
|
|
||||||
def decode_token(
|
def decode_token(
|
||||||
self,
|
self,
|
||||||
all_input_ids: List[int],
|
all_input_ids: List[int],
|
||||||
|
server/text_generation_server/models/mpt.py (new file, 98 lines)
@ -0,0 +1,98 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Type
|
||||||
|
from opentelemetry import trace
|
||||||
|
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
import json
|
||||||
|
|
||||||
|
from text_generation_server.models import CausalLM
|
||||||
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||||
|
from text_generation_server.pb import generate_pb2
|
||||||
|
from text_generation_server.models.custom_modeling.mpt_modeling import (
|
||||||
|
MPTForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MPTCausalLMBatch(CausalLMBatch):
|
||||||
|
@classmethod
|
||||||
|
def from_pb(
|
||||||
|
cls,
|
||||||
|
pb: generate_pb2.Batch,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> "CausalLMBatch":
|
||||||
|
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
||||||
|
batch.keys_head_dim_last = False
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
class MPTSharded(CausalLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("MPTSharded is only available on GPU")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
# If model_id is a local path, load the file directly
|
||||||
|
local_path = Path(model_id, "config.json")
|
||||||
|
if local_path.exists():
|
||||||
|
filename = str(local_path.resolve())
|
||||||
|
else:
|
||||||
|
filename = hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
config = PretrainedConfig(**config)
|
||||||
|
config.quantize = quantize
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
|
||||||
|
config.quantize = quantize
|
||||||
|
model = MPTForCausalLM(config, weights)
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_type(self) -> Type[CausalLMBatch]:
|
||||||
|
return MPTCausalLMBatch
|
@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
@ -12,11 +12,12 @@ class RW(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
read_offsets.append(1)
|
read_offsets.append(1)
|
||||||
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
|
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
|
||||||
|
|
||||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=pb.id,
|
batch_id=pb.id,
|
||||||
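The max_tokens change in the hunk above is pure arithmetic: the old expression only multiplied the input side by the batch size, so the decode budget was under-counted whenever the batch held more than one request. A quick worked check, with illustrative numbers:

    # e.g. 4 requests, 50 input tokens each, up to 20 decode tokens each
    batch, max_input_length, max_decode_tokens = 4, 50, 20

    old = batch * max_input_length + max_decode_tokens    # 220: misses 3 * 20 decode tokens
    new = batch * (max_input_length + max_decode_tokens)  # 280: full per-request budget
    print(old, new)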
@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
@ -53,6 +53,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||||||
|
|
||||||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||||
|
|
||||||
|
async def Warmup(self, request, context):
|
||||||
|
batch = self.model.batch_type.from_pb(
|
||||||
|
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||||
|
)
|
||||||
|
self.model.warmup(batch, request.max_total_tokens)
|
||||||
|
return generate_pb2.WarmupResponse()
|
||||||
|
|
||||||
async def Prefill(self, request, context):
|
async def Prefill(self, request, context):
|
||||||
batch = self.model.batch_type.from_pb(
|
batch = self.model.batch_type.from_pb(
|
||||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||||
@ -99,6 +106,7 @@ def serve(
|
|||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
sharded: bool,
|
sharded: bool,
|
||||||
quantize: Optional[str],
|
quantize: Optional[str],
|
||||||
|
dtype: Optional[str],
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
uds_path: Path,
|
uds_path: Path,
|
||||||
):
|
):
|
||||||
@ -107,6 +115,7 @@ def serve(
|
|||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
sharded: bool = False,
|
sharded: bool = False,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[str] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
unix_socket_template = "unix://{}-{}"
|
unix_socket_template = "unix://{}-{}"
|
||||||
@ -121,7 +130,9 @@ def serve(
|
|||||||
server_urls = [local_url]
|
server_urls = [local_url]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
|
model = get_model(
|
||||||
|
model_id, revision, sharded, quantize, dtype, trust_remote_code
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error when initializing model")
|
logger.exception("Error when initializing model")
|
||||||
raise
|
raise
|
||||||
@ -152,4 +163,6 @@ def serve(
|
|||||||
logger.info("Signal received. Shutting down")
|
logger.info("Signal received. Shutting down")
|
||||||
await server.stop(0)
|
await server.stop(0)
|
||||||
|
|
||||||
asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
|
asyncio.run(
|
||||||
|
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
|
||||||
|
)
|
||||||
|
@ -32,7 +32,19 @@ def load_layer_norm(cls, prefix, weights, eps):
|
|||||||
return ln
|
return ln
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_layer_norm_no_bias(cls, prefix, weights, eps):
|
||||||
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
with init_empty_weights():
|
||||||
|
ln = cls(weight.shape, eps=eps)
|
||||||
|
|
||||||
|
ln.weight = nn.Parameter(weight)
|
||||||
|
ln.bias = None
|
||||||
|
return ln
|
||||||
|
|
||||||
|
|
||||||
torch.nn.LayerNorm.load = load_layer_norm
|
torch.nn.LayerNorm.load = load_layer_norm
|
||||||
|
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
||||||
|
|
||||||
|
|
||||||
class FastLinear(nn.Module):
|
class FastLinear(nn.Module):
|
||||||
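The new load_no_bias classmethod above mirrors load_layer_norm but leaves the bias unset. A hedged usage sketch follows; it assumes the monkey-patched classmethod from this commit has been imported, and the stub Weights object, prefix, and eps values are made up for illustration only.

    import torch
    import torch.nn as nn

    class _StubWeights:
        """Stand-in for the Weights class; only get_tensor is sketched here."""
        def __init__(self, tensors):
            self.tensors = tensors

        def get_tensor(self, name):
            return self.tensors[name]

    weights = _StubWeights({"model.norm.weight": torch.ones(1024)})

    # Assumes the load_no_bias patch onto nn.LayerNorm from this commit is in scope.
    ln = nn.LayerNorm.load_no_bias(prefix="model.norm", weights=weights, eps=1e-6)
    assert ln.bias is None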
@ -357,7 +369,7 @@ try:
|
|||||||
def __init__(self, inv_freq):
|
def __init__(self, inv_freq):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
self.inv_freq = inv_freq
|
||||||
self._seq_len_cached = 0
|
self._seq_len_cached = 0
|
||||||
self._cos_cached = None
|
self._cos_cached = None
|
||||||
self._sin_cached = None
|
self._sin_cached = None
|
||||||
|
@ -189,7 +189,6 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
|||||||
|
|
||||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||||
sorted_indices_to_remove = probs <= self.top_p_opposite
|
sorted_indices_to_remove = probs <= self.top_p_opposite
|
||||||
if self.min_tokens_to_keep > 1:
|
|
||||||
# Keep at least min_tokens_to_keep
|
# Keep at least min_tokens_to_keep
|
||||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||||
|
|
||||||
|
@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser:
|
|||||||
|
|
||||||
self.seeds = seeds
|
self.seeds = seeds
|
||||||
self.do_sample = do_sample
|
self.do_sample = do_sample
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
|
||||||
if self.watermark_processor is not None:
|
if self.watermark_processor is not None:
|
||||||
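Storing dtype and device on the chooser lets the batch-concatenation path rebuild a HeterogeneousNextTokenChooser for the merged batch from batches[0].next_token_chooser.dtype and .device instead of threading those values through separately. A minimal sketch of that pattern, with a toy class standing in for the real chooser:

    import torch

    class ToyChooser:
        """Illustrative stand-in for HeterogeneousNextTokenChooser."""

        def __init__(self, parameters, dtype, device):
            self.parameters = parameters
            self.dtype = dtype
            self.device = device

        @classmethod
        def from_pb(cls, parameters, dtype, device):
            return cls(parameters, dtype, device)

    a = ToyChooser.from_pb(["req0"], torch.float16, torch.device("cpu"))
    b = ToyChooser.from_pb(["req1"], torch.float16, torch.device("cpu"))

    # Concatenation rebuilds one chooser for the merged batch from the first batch's metadata.
    merged = ToyChooser.from_pb(a.parameters + b.parameters, dtype=a.dtype, device=a.device)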
|
@ -3,8 +3,16 @@ from typing import List, Dict, Optional
from safetensors import safe_open
import torch


class Weights:
    def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
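The new aliases argument lets one checkpoint tensor answer for several names; the Santacoder change earlier in this diff maps transformer.wte.weight to lm_head.weight. A rough sketch of how such a lookup can resolve an alias, with made-up internals rather than the real Weights routing code:

    from typing import Dict, List, Optional

    class ToyRouting:
        """Illustrative alias lookup only; not the real Weights logic."""

        def __init__(self, routing: Dict[str, str], aliases: Optional[Dict[str, List[str]]] = None):
            self.routing = routing            # tensor name -> safetensors file
            self.aliases = aliases or {}

        def resolve(self, name: str) -> str:
            if name in self.routing:
                return self.routing[name]
            # Fall back to any canonical tensor that lists `name` as an alias.
            for canonical, alias_names in self.aliases.items():
                if name in alias_names and canonical in self.routing:
                    return self.routing[canonical]
            raise KeyError(f"weight {name} not found in any file")

    r = ToyRouting(
        routing={"transformer.wte.weight": "model.safetensors"},
        aliases={"transformer.wte.weight": ["lm_head.weight"]},
    )
    print(r.resolve("lm_head.weight"))  # -> model.safetensors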