Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)

Commit cc744ba426 (parent 7a6fad6aac): Add changes from Optimum Habana's TGI folder

Dockerfile (159 lines changed)
@@ -2,8 +2,6 @@
 FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
 WORKDIR /usr/src
 
-ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
-
 FROM chef as planner
 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -15,9 +13,6 @@ RUN cargo chef prepare --recipe-path recipe.json
 
 FROM chef AS builder
 
-ARG GIT_SHA
-ARG DOCKER_LABEL
-
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -35,123 +30,18 @@ COPY router router
 COPY launcher launcher
 RUN cargo build --release
 
-# Python builder
-# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
-FROM debian:bullseye-slim as pytorch-install
-
-ARG PYTORCH_VERSION=2.0.1
-ARG PYTHON_VERSION=3.9
-# Keep in sync with `server/pyproject.toml
-ARG CUDA_VERSION=11.8
-ARG MAMBA_VERSION=23.1.0-1
-ARG CUDA_CHANNEL=nvidia
-ARG INSTALL_CHANNEL=pytorch
-# Automatically set by buildx
-ARG TARGETPLATFORM
-
-ENV PATH /opt/conda/bin:$PATH
-
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    build-essential \
-    ca-certificates \
-    ccache \
-    curl \
-    git && \
-    rm -rf /var/lib/apt/lists/*
-
-# Install conda
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
-    *) MAMBA_ARCH=x86_64 ;; \
-    esac && \
-    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
-RUN chmod +x ~/mambaforge.sh && \
-    bash ~/mambaforge.sh -b -p /opt/conda && \
-    rm ~/mambaforge.sh
-
-# Install pytorch
-# On arm64 we exit with an error code
-RUN case ${TARGETPLATFORM} in \
-    "linux/arm64") exit 1 ;; \
-    *) /opt/conda/bin/conda update -y conda && \
-    /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
-    esac && \
-    /opt/conda/bin/conda clean -ya
-
-# CUDA kernels builder image
-FROM pytorch-install as kernel-builder
-
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    ninja-build \
-    && rm -rf /var/lib/apt/lists/*
-
-RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
-    /opt/conda/bin/conda clean -ya
-
-# Build Flash Attention CUDA kernels
-FROM kernel-builder as flash-att-builder
-
-WORKDIR /usr/src
-
-COPY server/Makefile-flash-att Makefile
-
-# Build specific version of flash attention
-RUN make build-flash-attention
-
-# Build Flash Attention v2 CUDA kernels
-FROM kernel-builder as flash-att-v2-builder
-
-WORKDIR /usr/src
-
-COPY server/Makefile-flash-att-v2 Makefile
-
-# Build specific version of flash attention v2
-RUN make build-flash-attention-v2
-
-# Build Transformers exllama kernels
-FROM kernel-builder as exllama-kernels-builder
-WORKDIR /usr/src
-COPY server/exllama_kernels/ .
-# Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
-
-# Build Transformers awq kernels
-FROM kernel-builder as awq-kernels-builder
-WORKDIR /usr/src
-COPY server/Makefile-awq Makefile
-# Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
-
-# Build Transformers CUDA kernels
-FROM kernel-builder as custom-kernels-builder
-WORKDIR /usr/src
-COPY server/custom_kernels/ .
-# Build specific version of transformers
-RUN python setup.py build
-
-# Build vllm CUDA kernels
-FROM kernel-builder as vllm-builder
-
-WORKDIR /usr/src
-
-COPY server/Makefile-vllm Makefile
-
-# Build specific version of vllm
-RUN make build-vllm
-
 # Text Generation Inference base image
-FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
+FROM vault.habana.ai/gaudi-docker/1.13.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.0:latest as base
 
-# Conda env
-ENV PATH=/opt/conda/bin:$PATH \
-    CONDA_PREFIX=/opt/conda
-
 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
 
+# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
 WORKDIR /usr/src
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
@@ -161,30 +51,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     curl \
     && rm -rf /var/lib/apt/lists/*
 
-# Copy conda with PyTorch installed
-COPY --from=pytorch-install /opt/conda /opt/conda
-
-# Copy build artifacts from flash attention builder
-COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-
-# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-
-# Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-# Copy build artifacts from awq kernels builder
-COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-
-# Copy builds artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
-
-# Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
-
 # Install server
 COPY proto proto
 COPY server server
@@ -192,7 +58,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements.txt && \
-    pip install ".[bnb, accelerate, quantize]" --no-cache-dir
+    pip install . --no-cache-dir
 
 # Install benchmarker
 COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -201,19 +67,6 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
 COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
 
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    build-essential \
-    g++ \
-    && rm -rf /var/lib/apt/lists/*
-
-# AWS Sagemaker compatbile image
-FROM base as sagemaker
-
-COPY sagemaker-entrypoint.sh entrypoint.sh
-RUN chmod +x entrypoint.sh
-
-ENTRYPOINT ["./entrypoint.sh"]
-
 # Final image
 FROM base
 
Makefile (9 lines changed)
@@ -1,9 +1,6 @@
 install-server:
 	cd server && make install
 
-install-custom-kernels:
-	if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
-
 install-integration-tests:
 	cd integration-tests && pip install -r requirements.txt
 	cd clients/python && pip install .
@@ -45,8 +42,8 @@ python-tests: python-server-tests python-client-tests
 run-falcon-7b-instruct:
 	text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080
 
-run-falcon-7b-instruct-quantize:
-	text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes --port 8080
-
 clean:
 	rm -rf target aml
+
+debug_image_build:
+	docker build --no-cache --progress=plain -t debug_tgi .
README.md (359 lines changed)
@@ -1,288 +1,83 @@
+<!---
+Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Text Generation Inference on Habana Gaudi
+
+To use [🤗 text-generation-inference](https://github.com/huggingface/text-generation-inference) on Habana Gaudi/Gaudi2, follow these steps:
+
+1. Build the Docker image located in this folder with:
+   ```bash
+   docker build -t tgi_gaudi .
+   ```
+2. Launch a local server instance on 1 Gaudi card:
+   ```bash
+   model=meta-llama/Llama-2-7b-hf
+   volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+   docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host tgi_gaudi --model-id $model
+   ```
+3. Launch a local server instance on 8 Gaudi cards:
+   ```bash
+   model=meta-llama/Llama-2-70b-hf
+   volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+   docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host tgi_gaudi --model-id $model --sharded true --num-shard 8
+   ```
+4. You can then send a request:
+   ```bash
+   curl 127.0.0.1:8080/generate \
+     -X POST \
+     -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17, "do_sample": true}}' \
+     -H 'Content-Type: application/json'
+   ```
+   > The first call will be slower as the model is compiled.
+5. To run a benchmark test, please refer to [TGI's benchmark tool](https://github.com/huggingface/text-generation-inference/tree/main/benchmark).
+
+   To run it on the same machine, you can do the following:
+   * `docker exec -it <docker name> bash`, pick the docker started in step 3 or 4 using `docker ps`
+   * `text-generation-benchmark -t <model-id>`, pass the model-id from the docker run command
+   * after the completion of tests, hit ctrl+c to see the performance data summary.
+
+> For gated models such as [StarCoder](https://huggingface.co/bigcode/starcoder), you will have to pass `-e HUGGING_FACE_HUB_TOKEN=<token>` to the `docker run` command above with a valid Hugging Face Hub read token.
+
+For more information and documentation about Text Generation Inference, check out [the README](https://github.com/huggingface/text-generation-inference#text-generation-inference) of the original repo.
+
+Not all features of TGI are currently supported, as this is still a work in progress.
+
+New changes are added for the current release:
+- Sharded feature with support for DeepSpeed-inference auto tensor parallelism. Also use HPU graphs for performance improvement.
+- Torch profile.
+
+Environment variables added:
+
 <div align="center">
 
-
-# Text Generation Inference
-
-<a href="https://github.com/huggingface/text-generation-inference">
-  <img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/huggingface/text-generation-inference?style=social">
-</a>
-<a href="https://huggingface.github.io/text-generation-inference">
-  <img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
-</a>
-
-A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power Hugging Chat, the Inference API and Inference Endpoint.
-
+| Name | Value(s) | Default | Description | Usage |
+|------------------|:---------------|:------------|:--------------------|:---------------------------------|
+| MAX_TOTAL_TOKENS | integer | 0 | Control the padding of input | add -e in docker run, such |
+| ENABLE_HPU_GRAPH | true/false | true | Enable hpu graph or not | add -e in docker run command |
+| PROF_WARMUPSTEP | integer | 0 | Enable/disable profile, control profile warmup step, 0 means disable profile | add -e in docker run command |
+| PROF_STEP | integer | 5 | Control profile step | add -e in docker run command |
+| PROF_PATH | string | /root/text-generation-inference | Define profile folder | add -e in docker run command |
+| LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory | add -e in docker run command |
+
 </div>
 
-## Table of contents
-
-- [Features](#features)
-- [Optimized Architectures](#optimized-architectures)
-- [Get Started](#get-started)
-  - [Docker](#docker)
-  - [API Documentation](#api-documentation)
-  - [Using a private or gated model](#using-a-private-or-gated-model)
-  - [A note on Shared Memory](#a-note-on-shared-memory-shm)
-  - [Distributed Tracing](#distributed-tracing)
-  - [Local Install](#local-install)
-  - [CUDA Kernels](#cuda-kernels)
-- [Run Falcon](#run-falcon)
-  - [Run](#run)
-  - [Quantization](#quantization)
-- [Develop](#develop)
-  - [Testing](#testing)
-- [Other supported hardware](#other-supported-hardware)
-
-## Features
-
-- Serve the most popular Large Language Models with a simple launcher
-- Tensor Parallelism for faster inference on multiple GPUs
-- Token streaming using Server-Sent Events (SSE)
-- [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
-- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
-- [Safetensors](https://github.com/huggingface/safetensors) weight loading
-- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
-- Stop sequences
-- Log probabilities
-- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
-- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output.
-- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance.
-
-## Optimized architectures
-
-- [BLOOM](https://huggingface.co/bigscience/bloom)
-- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
-- [Galactica](https://huggingface.co/facebook/galactica-120b)
-- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
-- [Llama](https://github.com/facebookresearch/llama)
-- [OPT](https://huggingface.co/facebook/opt-66b)
-- [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [Starcoder](https://huggingface.co/bigcode/starcoder)
-- [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b)
-- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
-- [MPT](https://huggingface.co/mosaicml/mpt-30b)
-- [Llama V2](https://huggingface.co/meta-llama)
-- [Code Llama](https://huggingface.co/codellama)
-- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
-
-Other architectures are supported on a best effort basis using:
-
-`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
-
-or
-
-`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
-
-## Get started
-
-### Docker
-
-The easiest way of getting started is using the official Docker container:
-
-```shell
-model=tiiuae/falcon-7b-instruct
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model
-```
-**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
-
-To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
-```
-text-generation-launcher --help
-```
-
-You can then query the model using either the `/generate` or `/generate_stream` routes:
-
-```shell
-curl 127.0.0.1:8080/generate \
-    -X POST \
-    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-    -H 'Content-Type: application/json'
-```
-
-```shell
-curl 127.0.0.1:8080/generate_stream \
-    -X POST \
-    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-    -H 'Content-Type: application/json'
-```
-
-or from Python:
-
-```shell
-pip install text-generation
-```
-
-```python
-from text_generation import Client
-
-client = Client("http://127.0.0.1:8080")
-print(client.generate("What is Deep Learning?", max_new_tokens=20).generated_text)
-
-text = ""
-for response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
-    if not response.token.special:
-        text += response.token.text
-print(text)
-```
-
-### API documentation
-
-You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
-The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
-
-### Using a private or gated model
-
-You have the option to utilize the `HUGGING_FACE_HUB_TOKEN` environment variable for configuring the token employed by
-`text-generation-inference`. This allows you to gain access to protected resources.
-
-For example, if you want to serve the gated Llama V2 model variants:
-
-1. Go to https://huggingface.co/settings/tokens
-2. Copy your cli READ token
-3. Export `HUGGING_FACE_HUB_TOKEN=<your cli READ token>`
-
-or with Docker:
-
-```shell
-model=meta-llama/Llama-2-7b-chat-hf
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-token=<your cli READ token>
-
-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model
-```
-
-### A note on Shared Memory (shm)
-
-[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
-`PyTorch` to do distributed training/inference. `text-generation-inference` make
-use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
-
-In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
-peer-to-peer using NVLink or PCI is not possible.
-
-To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command.
-
-If you are running `text-generation-inference` inside `Kubernetes`. You can also add Shared Memory to the container by
-creating a volume with:
-
-```yaml
-- name: shm
-  emptyDir:
-   medium: Memory
-   sizeLimit: 1Gi
-```
-
-and mounting it to `/dev/shm`.
-
-Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that
-this will impact performance.
-
-### Distributed Tracing
-
-`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
-by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
-
-### Local install
-
-You can also opt to install `text-generation-inference` locally.
-
-First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
-Python 3.9, e.g. using `conda`:
-
-```shell
-curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
-
-conda create -n text-generation-inference python=3.9
-conda activate text-generation-inference
-```
-
-You may also need to install Protoc.
-
-On Linux:
-
-```shell
-PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
-curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
-sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
-sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
-rm -f $PROTOC_ZIP
-```
-
-On MacOS, using Homebrew:
-
-```shell
-brew install protobuf
-```
-
-Then run:
-
-```shell
-BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
-make run-falcon-7b-instruct
-```
-
-**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
-
-```shell
-sudo apt-get install libssl-dev gcc -y
-```
-
-### CUDA Kernels
-
-The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
-the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
-
-Be aware that the official Docker image has them enabled by default.
-
-## Run Falcon
-
-### Run
-
-```shell
-make run-falcon-7b-instruct
-```
-
-### Quantization
-
-You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
-
-```shell
-make run-falcon-7b-instruct-quantize
-```
-
-4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
-
-## Develop
-
-```shell
-make server-dev
-make router-dev
-```
-
-## Testing
-
-```shell
-# python
-make python-server-tests
-make python-client-tests
-# or both server and client tests
-make python-tests
-# rust cargo tests
-make rust-tests
-# integration tests
-make integration-tests
-```
-
-## Other supported hardware
-
-TGI is also supported on the following AI hardware accelerators:
-- *Habana first-gen Gaudi and Gaudi2:* checkout [here](https://github.com/huggingface/optimum-habana/tree/main/text-generation-inference) how to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
-
+
+> The license to use TGI on Habana Gaudi is the one of TGI: https://github.com/huggingface/text-generation-inference/blob/main/LICENSE
+>
+> Please reach out to api-enterprise@huggingface.co if you have any question.
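As a sketch of how the "add -e in docker run" usage column of the table above combines with the `docker run` command from step 2 of the new README, assuming placeholder values for `MAX_TOTAL_TOKENS`, `PROF_WARMUPSTEP` and the model id:

```bash
# Illustrative only: env var values are placeholders, not recommendations.
model=meta-llama/Llama-2-7b-hf
volume=$PWD/data

docker run -p 8080:80 -v $volume:/data --runtime=habana \
  -e HABANA_VISIBLE_DEVICES=all \
  -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
  -e MAX_TOTAL_TOKENS=2048 \
  -e ENABLE_HPU_GRAPH=true \
  -e PROF_WARMUPSTEP=0 \
  --cap-add=sys_nice --ipc=host \
  tgi_gaudi --model-id $model
```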
launcher/src/main.rs

@@ -459,7 +459,9 @@ fn shard_manager(
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
     // Torch Distributed Env vars
-    envs.push(("RANK".into(), rank.to_string().into()));
+    if world_size == 1 {
+        envs.push(("RANK".into(), rank.to_string().into()));
+    }
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
     envs.push(("MASTER_ADDR".into(), master_addr.into()));
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
@@ -870,7 +872,7 @@ fn spawn_shards(
     running: Arc<AtomicBool>,
 ) -> Result<(), LauncherError> {
     // Start shard processes
-    for rank in 0..num_shard {
+    for rank in 0..1 {
         let model_id = args.model_id.clone();
         let revision = args.revision.clone();
         let uds_path = args.shard_uds_path.clone();
@@ -921,12 +923,12 @@ fn spawn_shards(
     drop(shutdown_sender);
 
     // Wait for shard to start
     let mut shard_ready = 0;
     while running.load(Ordering::SeqCst) {
         match status_receiver.try_recv() {
             Ok(ShardStatus::Ready) => {
                 shard_ready += 1;
-                if shard_ready == num_shard {
+                if shard_ready == 1 {
                     break;
                 }
             }
server/Makefile

@@ -16,11 +16,7 @@ gen-server:
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch text_generation_server/pb/__init__.py
 
-install-torch:
-	# Install specific version of torch
-	pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
-
-install: gen-server install-torch
+install: gen-server
 	pip install pip --upgrade
 	pip install -r requirements.txt
 	pip install -e ".[bnb, accelerate]"
@@ -28,5 +24,12 @@ install: gen-server install-torch
 run-dev:
 	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
 
+install-poetry:
+	curl -sSL https://install.python-poetry.org | python3 -
+
+update-lock:
+	rm poetry.lock
+	poetry lock --no-update
+
 export-requirements:
-	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
+	poetry export -f requirements.txt --without-hashes --output requirements.txt
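For reference, a possible invocation of the lock-file targets added above, assuming Poetry ends up on PATH after the installer runs:

```bash
cd server
make install-poetry       # installs Poetry via the official installer
make update-lock          # regenerates poetry.lock without bumping versions
make export-requirements  # writes requirements.txt from the lock file
```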
server/pyproject.toml

@@ -9,52 +9,32 @@ text-generation-server = 'text_generation_server.cli:app'
 
 [tool.poetry.dependencies]
 python = ">=3.9,<3.13"
-protobuf = "^4.21.7"
+protobuf = "^3.20.3"
 grpcio = "^1.51.1"
-grpcio-status = "^1.51.1"
-grpcio-reflection = "^1.51.1"
+grpcio-status = "*"
+grpcio-reflection = "*"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-accelerate = { version = "^0.20.0", optional = true }
-bitsandbytes = { version = "^0.41.1", optional = true }
-safetensors = "^0.3.2"
+safetensors = "0.3.2"
 loguru = "^0.6.0"
 opentelemetry-api = "^1.15.0"
 opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
-tokenizers = "^0.13.3"
+tokenizers = "^0.14.1"
 huggingface-hub = "^0.16.4"
-transformers = "^4.32.1"
-einops = "^0.6.1"
-texttable = { version = "^1.6.7", optional = true }
-datasets = { version = "^2.14.0", optional = true }
 peft = "^0.4.0"
-torch = { version = "^2.0.1" }
-scipy = "^1.11.1"
-pillow = "^10.0.0"
-
-[tool.poetry.extras]
-accelerate = ["accelerate"]
-bnb = ["bitsandbytes"]
-quantize = ["texttable", "datasets", "accelerate"]
+deepspeed = { git = "https://github.com/HabanaAI/DeepSpeed.git", branch = "1.13.0" }
+optimum-habana = { git = "https://github.com/huggingface/optimum-habana.git", branch = "main" }
 
 [tool.poetry.group.dev.dependencies]
-grpcio-tools = "^1.51.1"
+grpcio-tools = "*"
 pytest = "^7.3.0"
 
-
-[[tool.poetry.source]]
-name = "pytorch-gpu-src"
-url = "https://download.pytorch.org/whl/cu118"
-priority = "explicit"
-
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
 
 [build-system]
-requires = [
-    "poetry-core>=1.0.0",
-]
+requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
server/requirements.txt

@@ -1,30 +1,32 @@
-accelerate==0.20.3 ; python_version >= "3.9" and python_version < "3.13"
+accelerate>=0.22.0 ; python_version >= "3.9" and python_version < "3.13"
 aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13"
 aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
 async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
-datasets==2.14.5 ; python_version >= "3.9" and python_version < "3.13"
+coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13"
+datasets==2.14.4 ; python_version >= "3.9" and python_version < "3.13"
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
+diffusers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
 dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13"
-einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.12.4 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
 frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
 fsspec[http]==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.57.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
+humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==6.8.0 ; python_version >= "3.9" and python_version < "3.13"
 jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,7 +34,7 @@ mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
 multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
 multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13"
 networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
-numpy==1.26.0 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.25.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -42,34 +44,35 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+optimum==1.13.2 ; python_version >= "3.9" and python_version < "3.13"
 packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
-pandas==2.1.1 ; python_version >= "3.9" and python_version < "3.13"
+pandas==2.0.3 ; python_version >= "3.9" and python_version < "3.13"
 peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.0.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.0.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13"
 psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.13"
 pyarrow==13.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13"
 python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "3.13"
-pytz==2023.3.post1 ; python_version >= "3.9" and python_version < "3.13"
+pytz==2023.3 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2023.8.8 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.3.2 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.11.2 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==68.2.2 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==68.1.2 ; python_version >= "3.9" and python_version < "3.13"
 six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
-texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
-torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.14.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.34.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers[sentencepiece]==4.34.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
 tzdata==2023.3 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.0.5 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.16.2 ; python_version >= "3.9" and python_version < "3.13"
server/text_generation_server/cli.py

@@ -6,7 +6,6 @@ from pathlib import Path
 from loguru import logger
 from typing import Optional
 from enum import Enum
-from huggingface_hub import hf_hub_download
 
 
 app = typer.Typer()
@@ -14,11 +13,7 @@ app = typer.Typer()
 
 class Quantization(str, Enum):
     bitsandbytes = "bitsandbytes"
-    bitsandbytes_nf4 = "bitsandbytes-nf4"
-    bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
-    awq = "awq"
-    eetq = "eetq"
 
 
 class Dtype(str, Enum):
@@ -40,18 +35,9 @@ def serve(
     otlp_endpoint: Optional[str] = None,
 ):
     if sharded:
-        assert (
-            os.getenv("RANK", None) is not None
-        ), "RANK must be set when sharded is True"
-        assert (
-            os.getenv("WORLD_SIZE", None) is not None
-        ), "WORLD_SIZE must be set when sharded is True"
-        assert (
-            os.getenv("MASTER_ADDR", None) is not None
-        ), "MASTER_ADDR must be set when sharded is True"
-        assert (
-            os.getenv("MASTER_PORT", None) is not None
-        ), "MASTER_PORT must be set when sharded is True"
+        assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True"
+        assert os.getenv("MASTER_ADDR", None) is not None, "MASTER_ADDR must be set when sharded is True"
+        assert os.getenv("MASTER_PORT", None) is not None, "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
     logger.remove()
@@ -75,14 +61,29 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
-    dtype = None if dtype is None else dtype.value
-    if dtype is not None and quantize is not None:
-        raise RuntimeError(
-            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
-        )
-    server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
-    )
+    dtype = "bfloat16" if dtype is None else dtype.value
+
+    logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
+
+    if sharded:
+        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
+        num_shard = int(os.getenv("WORLD_SIZE", "1"))
+        logger.info("CLI SHARDED = {}".format(num_shard))
+        import subprocess
+
+        cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file} --model_id {model_id} --revision {revision} --sharded {sharded} --dtype {dtype} --uds_path {uds_path}"
+        logger.info("CLI server start deepspeed ={} ".format(cmd))
+        sys.stdout.flush()
+        sys.stderr.flush()
+        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
+            proc.wait()
+            sys.stdout.flush()
+            sys.stderr.flush()
+            if proc.returncode != 0:
+                logger.error(f"{cmd} exited with status = {proc.returncode}")
+            return proc.returncode
+    else:
+        server.serve(model_id, revision, dtype, uds_path, sharded)
 
 
 @app.command()
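For illustration only, the `cmd` string assembled by the new sharded branch above would look roughly like the following when `WORLD_SIZE=8`; the install path, model id, revision and socket path shown here are placeholders:

```bash
deepspeed --num_nodes 1 --num_gpus 8 --no_local_rank \
  /path/to/site-packages/text_generation_server/tgi_service.py \
  --model_id meta-llama/Llama-2-70b-hf --revision main \
  --sharded True --dtype bfloat16 --uds_path /tmp/text-generation-server
```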
@@ -93,7 +94,6 @@ def download_weights(
     auto_convert: bool = True,
     logger_level: str = "INFO",
     json_output: bool = False,
-    trust_remote_code: bool = False,
 ):
     # Remove default handler
     logger.remove()
@@ -124,19 +124,6 @@ def download_weights(
         ) is not None
 
     if not is_local_model:
-        try:
-            adapter_config_filename = hf_hub_download(
-                model_id, revision=revision, filename="adapter_config.json"
-            )
-            utils.download_and_unload_peft(
-                model_id, revision, trust_remote_code=trust_remote_code
-            )
-            is_local_model = True
-            utils.weight_files(model_id, revision, extension)
-            return
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass
-
         # Try to download weights from the hub
         try:
             filenames = utils.weight_hub_files(model_id, revision, extension)
@@ -175,30 +162,24 @@ def download_weights(
         )
 
         # Safetensors final filenames
-        local_st_files = [
-            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
-            for p in local_pt_files
-        ]
+        local_st_files = [p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files]
         try:
             import transformers
-            import json
+            from transformers import AutoConfig
 
-            if is_local_model:
-                config_filename = os.path.join(model_id, "config.json")
-            else:
-                config_filename = hf_hub_download(
-                    model_id, revision=revision, filename="config.json"
-                )
-            with open(config_filename, "r") as f:
-                config = json.load(f)
-            architecture = config["architectures"][0]
+            config = AutoConfig.from_pretrained(
+                model_id,
+                revision=revision,
+            )
+            architecture = config.architectures[0]
 
             class_ = getattr(transformers, architecture)
 
             # Name for this varible depends on transformers version.
             discard_names = getattr(class_, "_tied_weights_keys", [])
+            discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
 
-        except Exception as e:
+        except Exception:
             discard_names = []
         # Convert pytorch weights to safetensors
         utils.convert_files(local_pt_files, local_st_files, discard_names)
@@ -216,8 +197,6 @@ def quantize(
     percdamp: float = 0.01,
     act_order: bool = False,
 ):
-    if revision is None:
-        revision = "main"
     download_weights(
         model_id=model_id,
         revision=revision,
@@ -231,7 +210,6 @@ def quantize(
         bits=4,
         groupsize=128,
         output_dir=output_dir,
-        revision=revision,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
         percdamp=percdamp,
@ -1,336 +1,35 @@
-import os
import torch

from loguru import logger
-from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
+from transformers import AutoConfig
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
-from text_generation_server.models.bloom import BLOOMSharded
-from text_generation_server.models.mpt import MPTSharded
-from text_generation_server.models.seq2seq_lm import Seq2SeqLM
-from text_generation_server.models.rw import RW
-from text_generation_server.models.opt import OPTSharded
-from text_generation_server.models.galactica import GalacticaSharded
+from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.t5 import T5Sharded
-from text_generation_server.models.gpt_neox import GPTNeoxSharded

-# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-# in PyTorch 1.12 and later.
-torch.backends.cuda.matmul.allow_tf32 = True
-
-# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-torch.backends.cudnn.allow_tf32 = True
-
# Disable gradients
torch.set_grad_enabled(False)
|
|
||||||
__all__ = [
|
|
||||||
"Model",
|
|
||||||
"BLOOMSharded",
|
|
||||||
"CausalLM",
|
|
||||||
"FlashCausalLM",
|
|
||||||
"GalacticaSharded",
|
|
||||||
"Seq2SeqLM",
|
|
||||||
"SantaCoder",
|
|
||||||
"OPTSharded",
|
|
||||||
"T5Sharded",
|
|
||||||
"get_model",
|
|
||||||
]
|
|
||||||
|
|
||||||
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
|
||||||
|
|
||||||
FLASH_ATTENTION = True
|
|
||||||
try:
|
|
||||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
|
||||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
|
||||||
from text_generation_server.models.flash_llama import (
|
|
||||||
FlashLlama,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.flash_santacoder import (
|
|
||||||
FlashSantacoderSharded,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.idefics import IDEFICSSharded
|
|
||||||
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
|
||||||
FLASH_ATTENTION = False
|
|
||||||
|
|
||||||
if FLASH_ATTENTION:
|
|
||||||
__all__.append(FlashNeoXSharded)
|
|
||||||
__all__.append(FlashRWSharded)
|
|
||||||
__all__.append(FlashSantacoderSharded)
|
|
||||||
__all__.append(FlashLlama)
|
|
||||||
__all__.append(IDEFICSSharded)
|
|
||||||
|
|
||||||
MISTRAL = True
|
|
||||||
try:
|
|
||||||
from text_generation_server.models.flash_mistral import FlashMistral
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning(f"Could not import Mistral model: {e}")
|
|
||||||
MISTRAL = False
|
|
||||||
|
|
||||||
if MISTRAL:
|
|
||||||
__all__.append(FlashMistral)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(
    model_id: str,
    revision: Optional[str],
-    sharded: bool,
-    quantize: Optional[str],
-    dtype: Optional[str],
-    trust_remote_code: bool,
+    dtype: Optional[torch.dtype] = None,
) -> Model:
-    if dtype is None:
-        dtype = torch.float16
-    elif dtype == "float16":
-        dtype = torch.float16
-    elif dtype == "bfloat16":
-        dtype = torch.bfloat16
-    else:
-        raise RuntimeError(f"Unknown dtype {dtype}")
+    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    model_type = config.model_type

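With the hunk above, get_model no longer parses dtype strings or a raw config dict; it reads model_type straight from AutoConfig and dispatches on it. A minimal sketch of that lookup (the model id is only an example and its config is fetched from the Hub):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigscience/bloom-560m")
print(config.model_type)        # "bloom" -> routed to the BLOOM(CausalLM) branch below
print(config.architectures[0])  # "BloomForCausalLM"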
if "facebook/galactica" in model_id:
|
|
||||||
return GalacticaSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_id.startswith("bigcode/"):
|
|
||||||
if FLASH_ATTENTION:
|
|
||||||
return FlashSantacoderSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
elif sharded:
|
|
||||||
raise NotImplementedError(
|
|
||||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return SantaCoder(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
model_type = config_dict["model_type"]
|
|
||||||
|
|
||||||
if model_type == "gpt_bigcode":
|
if model_type == "gpt_bigcode":
|
||||||
if FLASH_ATTENTION:
|
return SantaCoder(model_id, revision, dtype)
|
||||||
return FlashSantacoderSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
elif sharded:
|
|
||||||
raise NotImplementedError(
|
|
||||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return SantaCoder(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_type == "bloom":
|
if model_type == "bloom":
|
||||||
return BLOOMSharded(
|
return BLOOM(model_id, revision, dtype)
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
elif model_type == "mpt":
|
|
||||||
return MPTSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == "gpt_neox":
|
|
||||||
if FLASH_ATTENTION:
|
|
||||||
return FlashNeoXSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
elif sharded:
|
|
||||||
return GPTNeoxSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return CausalLM(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == "llama" or model_type == "baichuan":
|
|
||||||
if FLASH_ATTENTION:
|
|
||||||
return FlashLlama(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
elif sharded:
|
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
|
||||||
else:
|
|
||||||
return CausalLM(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
|
|
||||||
if sharded:
|
|
||||||
if FLASH_ATTENTION:
|
|
||||||
if config_dict.get("alibi", False):
|
|
||||||
raise NotImplementedError("sharded is not supported for this model")
|
|
||||||
return FlashRWSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
|
||||||
else:
|
|
||||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
|
||||||
return FlashRWSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return RW(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_type == "mistral":
|
|
||||||
if MISTRAL:
|
|
||||||
return FlashMistral(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
raise NotImplementedError("Mistral model requires flash attention v2")
|
|
||||||
|
|
||||||
if model_type == "opt":
|
|
||||||
return OPTSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_type == "t5":
|
|
||||||
return T5Sharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
if model_type == "idefics":
|
|
||||||
if FLASH_ATTENTION:
|
|
||||||
return IDEFICSSharded(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
|
||||||
|
|
||||||
if sharded:
|
|
||||||
raise ValueError("sharded is not supported for AutoModel")
|
|
||||||
if quantize == "gptq":
|
|
||||||
raise ValueError(
|
|
||||||
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
|
||||||
)
|
|
||||||
if quantize == "awq":
|
|
||||||
raise ValueError("awq quantization is not supported for AutoModel")
|
|
||||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
|
||||||
raise ValueError("4bit quantization is not supported for AutoModel")
|
|
||||||
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(
-            model_id,
-            revision,
-            quantize=quantize,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
+        return CausalLM(model_id, revision, dtype)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
|
||||||
return Seq2SeqLM(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
auto_map = config_dict.get("auto_map", None)
|
|
||||||
if trust_remote_code and auto_map is not None:
|
|
||||||
if "AutoModelForCausalLM" in auto_map.keys():
|
|
||||||
return CausalLM(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
if "AutoModelForSeq2SeqLM" in auto_map.keys():
|
|
||||||
return Seq2SeqLM(
|
|
||||||
model_id,
|
|
||||||
revision,
|
|
||||||
quantize=quantize,
|
|
||||||
dtype=dtype,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported model type {model_type}")
|
raise ValueError(f"Unsupported model type {model_type}")
|
||||||
|
@ -1,25 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from transformers import (
|
from transformers import PreTrainedTokenizerBase
|
||||||
AutoTokenizer,
|
|
||||||
AutoConfig,
|
|
||||||
PreTrainedTokenizerBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
|
||||||
BloomForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.models import CausalLM
|
from text_generation_server.models import CausalLM
|
||||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BloomCausalLMBatch(CausalLMBatch):
|
class BloomCausalLMBatch(CausalLMBatch):
|
||||||
@ -30,82 +17,32 @@ class BloomCausalLMBatch(CausalLMBatch):
|
|||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
is_optimized_for_gaudi: bool = False,
|
||||||
) -> "CausalLMBatch":
|
) -> "CausalLMBatch":
|
||||||
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
batch = super().from_pb(
|
||||||
|
pb=pb,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
is_optimized_for_gaudi=is_optimized_for_gaudi,
|
||||||
|
)
|
||||||
batch.keys_head_dim_last = False
|
batch.keys_head_dim_last = False
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
class BLOOMSharded(CausalLM):
|
class BLOOM(CausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
super(BLOOM, self).__init__(
|
||||||
if torch.cuda.is_available():
|
model_id=model_id,
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
revision=revision,
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
slow_but_exact=False,
|
|
||||||
tp_parallel=True,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
config.pad_token_id = 3
|
|
||||||
config.quantize = quantize
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
|
||||||
)
|
|
||||||
if config.quantize == "gptq":
|
|
||||||
weights._set_gptq_params(model_id)
|
|
||||||
|
|
||||||
model = BloomForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(CausalLM, self).__init__(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
requires_padding=True,
|
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[CausalLMBatch]:
|
def batch_type(self) -> Type[CausalLMBatch]:
|
||||||
return BloomCausalLMBatch
|
return BloomCausalLMBatch
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
|
||||||
):
|
|
||||||
outputs = self.model.forward(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = outputs.logits
|
|
||||||
return logits, outputs.past_key_values
|
|
||||||
|
@ -1,11 +1,24 @@
+import os
+import tempfile

from text_generation_server.utils.tokens import batch_top_tokens
import torch
-import inspect

from dataclasses import dataclass
from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Dict
+from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+import habana_frameworks.torch as htorch
+from contextlib import nullcontext
+from optimum.habana.utils import HabanaProfile

+from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
+from optimum.habana.checkpoint_utils import (
+    get_repo_root,
+    model_on_meta,
+    write_checkpoints_json,
+)

from text_generation_server.models import Model
from text_generation_server.models.types import (
@ -16,11 +29,11 @@ from text_generation_server.models.types import (
    TopTokens,
)
from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling
+from loguru import logger

tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CausalLMBatch(Batch):
|
class CausalLMBatch(Batch):
|
||||||
batch_id: int
|
batch_id: int
|
||||||
@ -42,7 +55,7 @@ class CausalLMBatch(Batch):
|
|||||||
read_offsets: List[int]
|
read_offsets: List[int]
|
||||||
|
|
||||||
# Generation helpers
|
# Generation helpers
|
||||||
next_token_choosers: List[NextTokenChooser]
|
next_token_chooser: HeterogeneousNextTokenChooser
|
||||||
stopping_criterias: List[StoppingCriteria]
|
stopping_criterias: List[StoppingCriteria]
|
||||||
top_n_tokens: List[int]
|
top_n_tokens: List[int]
|
||||||
top_n_tokens_tensor: torch.Tensor
|
top_n_tokens_tensor: torch.Tensor
|
||||||
@ -72,66 +85,90 @@ class CausalLMBatch(Batch):
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
+        is_optimized_for_gaudi: bool = False,
    ) -> "CausalLMBatch":
        inputs = []
-        next_token_choosers = []
+        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}
+        input_lengths = []

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0

+        # TODO: this should be set to rust side `max_total_tokens`,
+        # (see https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs#L177)
+        # but TGI does not offer an API to expose this variable to python, as this variable
+        # is handled by the client but it appears the model is initialized by the server.
+        # An alternative could be to initialize the buffers during warmup.
+        # Dummy
+        max_total_tokens = int(os.getenv("MAX_TOTAL_TOKENS", "0"))
+        logger.info("MAX_TOTAL_TOKENS = {}".format(max_total_tokens))

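The TODO above works around the fact that the router's --max-total-tokens value is not forwarded to the Python server, so the Gaudi code reads it from an environment variable and uses it as a fixed padding target. A CPU-runnable sketch of that convention (all values are made up):

import os

os.environ["MAX_TOTAL_TOKENS"] = "2048"   # normally exported by the launcher
max_total_tokens = int(os.getenv("MAX_TOTAL_TOKENS", "0"))

max_input_length = 1024                   # longest prompt in the batch
if max_total_tokens == 0:
    max_total_tokens = max_input_length   # same fallback as the code above

# Pad every request to the same total length so a change in max_new_tokens
# does not force a new HPU graph to be compiled.
padding_right_offset = max_total_tokens - max_input_length
print(padding_right_offset)               # 1024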
for i, r in enumerate(pb.requests):
|
for i, r in enumerate(pb.requests):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_chooser_parameters.append(r.parameters)
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||||
r.stopping_parameters, tokenizer
|
|
||||||
)
|
|
||||||
stopping_criterias.append(stopping_criteria)
|
stopping_criterias.append(stopping_criteria)
|
||||||
top_n_tokens.append(r.top_n_tokens)
|
top_n_tokens.append(r.top_n_tokens)
|
||||||
max_truncation = max(max_truncation, r.truncate)
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||||
padding_right_offset = max(
|
padding_right_offset = max(padding_right_offset, stopping_criteria.max_new_tokens)
|
||||||
padding_right_offset, stopping_criteria.max_new_tokens
|
|
||||||
)
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
|
next_token_chooser_parameters, dtype, device
|
||||||
|
)
|
||||||
|
|
||||||
tokenized_inputs = tokenizer(
|
tokenized_inputs = tokenizer(
|
||||||
inputs,
|
inputs,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding="max_length",
|
||||||
return_token_type_ids=False,
|
return_token_type_ids=False,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_truncation,
|
max_length=max_truncation,
|
||||||
).to(device)
|
)
|
||||||
|
|
||||||
for _ in pb.requests:
|
for _ in pb.requests:
|
||||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||||
|
input_lengths.append(input_len)
|
||||||
prefix_offsets.append(input_len - 5)
|
prefix_offsets.append(input_len - 5)
|
||||||
read_offsets.append(input_len)
|
read_offsets.append(input_len)
|
||||||
|
|
||||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
max_input_length = max(input_lengths)
|
||||||
max_input_length = input_lengths.max()
|
if max_total_tokens == 0:
|
||||||
|
max_total_tokens = max_input_length
|
||||||
|
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||||
|
if is_optimized_for_gaudi and max_total_tokens > max_input_length:
|
||||||
|
# pad to max_total_tokens in case max_new_token changes per request and triggers new hpu graph generation
|
||||||
|
padding_right_offset = max_total_tokens - max_input_length
|
||||||
|
|
||||||
input_ids = tokenized_inputs["input_ids"]
|
input_ids = tokenized_inputs["input_ids"]
|
||||||
# Allocate maximum attention_mask
|
attention_mask = tokenized_inputs["attention_mask"]
|
||||||
attention_mask = input_ids.new_zeros(
|
# only move model inputs to device
|
||||||
(pb.size, max_input_length + padding_right_offset)
|
attention_mask = attention_mask.to(device)
|
||||||
)
|
|
||||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
|
||||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
|
||||||
|
|
||||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
if is_optimized_for_gaudi:
|
||||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
input_ids_cpu = torch.nn.functional.pad(
|
||||||
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
input_ids, (0, padding_right_offset), value=tokenizer.pad_token_id
|
||||||
top_n_tokens_tensor = torch.tensor(
|
)
|
||||||
top_n_tokens, device=device, dtype=torch.int64
|
input_ids = input_ids_cpu.to(device)
|
||||||
)
|
attention_mask = torch.nn.functional.pad(attention_mask, (0, padding_right_offset), value=0)
|
||||||
|
all_input_ids = input_ids_cpu.T.split(1, dim=1)
|
||||||
|
else:
|
||||||
|
all_input_ids = input_ids.clone().T.split(1, dim=1)
|
||||||
|
input_ids = input_ids.to(device)
|
||||||
|
|
||||||
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
htorch.core.mark_step()
|
||||||
|
|
||||||
|
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
|
||||||
|
|
||||||
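The is_optimized_for_gaudi branch above keeps tensors at a fixed width by right-padding input_ids and attention_mask up front instead of letting them grow during decoding. A CPU-runnable sketch with toy values:

import torch
import torch.nn.functional as F

pad_token_id = 0
input_ids = torch.tensor([[11, 12, 13], [21, 22, pad_token_id]])   # [batch, max_input_length]
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
padding_right_offset = 4                                            # max_total_tokens - max_input_length

input_ids = F.pad(input_ids, (0, padding_right_offset), value=pad_token_id)
attention_mask = F.pad(attention_mask, (0, padding_right_offset), value=0)

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(input_ids.shape)   # torch.Size([2, 7]) -- the shape stays fixed for every decode step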
return cls(
|
return cls(
|
||||||
batch_id=pb.id,
|
batch_id=pb.id,
|
||||||
@ -142,20 +179,20 @@ class CausalLMBatch(Batch):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
all_input_ids=list(all_input_ids),
|
all_input_ids=list(all_input_ids),
|
||||||
input_lengths=input_lengths.tolist(),
|
input_lengths=input_lengths,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
max_input_length=max_input_length.item(),
|
max_input_length=max_input_length,
|
||||||
padding_right_offset=padding_right_offset,
|
padding_right_offset=padding_right_offset,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("filter")
|
@tracer.start_as_current_span("filter")
|
||||||
def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
|
def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
|
||||||
if len(request_ids) == 0:
|
if len(request_ids) == 0:
|
||||||
raise ValueError("Batch must have at least one request")
|
raise ValueError("Batch must have at least one request")
|
||||||
if len(request_ids) == len(self):
|
if len(request_ids) == len(self):
|
||||||
@ -172,7 +209,6 @@ class CausalLMBatch(Batch):
|
|||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
max_input_length = 0
|
max_input_length = 0
|
||||||
|
|
||||||
next_token_choosers = []
|
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
top_n_tokens = []
|
top_n_tokens = []
|
||||||
|
|
||||||
@ -193,52 +229,66 @@ class CausalLMBatch(Batch):
|
|||||||
input_lengths.append(request_input_length)
|
input_lengths.append(request_input_length)
|
||||||
max_input_length = max(max_input_length, request_input_length)
|
max_input_length = max(max_input_length, request_input_length)
|
||||||
|
|
||||||
next_token_choosers.append(self.next_token_choosers[idx])
|
|
||||||
stopping_criteria = self.stopping_criterias[idx]
|
stopping_criteria = self.stopping_criterias[idx]
|
||||||
stopping_criterias.append(stopping_criteria)
|
stopping_criterias.append(stopping_criteria)
|
||||||
top_n_tokens.append(self.top_n_tokens[idx])
|
top_n_tokens.append(self.top_n_tokens[idx])
|
||||||
remaining_decode_tokens = (
|
remaining_decode_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
|
||||||
)
|
|
||||||
total_remaining_decode_tokens += remaining_decode_tokens
|
total_remaining_decode_tokens += remaining_decode_tokens
|
||||||
new_padding_right_offset = max(
|
new_padding_right_offset = max(new_padding_right_offset, remaining_decode_tokens)
|
||||||
new_padding_right_offset, remaining_decode_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||||
input_ids = self.input_ids[keep_indices]
|
input_ids = self.input_ids[keep_indices]
|
||||||
position_ids = self.position_ids[keep_indices]
|
position_ids = self.position_ids[keep_indices]
|
||||||
self.attention_mask = self.attention_mask[
|
next_token_chooser = self.next_token_chooser.filter(keep_indices)
|
||||||
keep_indices,
|
if is_optimized_for_gaudi:
|
||||||
-(self.padding_right_offset + max_input_length) : (
|
self.attention_mask = self.attention_mask[keep_indices]
|
||||||
self.attention_mask.shape[1] - self.padding_right_offset
|
else:
|
||||||
)
|
self.attention_mask = self.attention_mask[
|
||||||
+ new_padding_right_offset,
|
keep_indices,
|
||||||
]
|
-(self.padding_right_offset + max_input_length) : (
|
||||||
|
self.attention_mask.shape[1] - self.padding_right_offset
|
||||||
|
)
|
||||||
|
+ new_padding_right_offset,
|
||||||
|
]
|
||||||
|
|
||||||
# Ensure that past_key_values tensors can be updated in-place
|
# Ensure that past_key_values tensors can be updated in-place
|
||||||
|
kv_tuple = False
|
||||||
if type(self.past_key_values[0]) == tuple:
|
if type(self.past_key_values[0]) == tuple:
|
||||||
self.past_key_values = [list(layer) for layer in self.past_key_values]
|
self.past_key_values = [list(layer) for layer in self.past_key_values]
|
||||||
|
kv_tuple = True
|
||||||
|
|
||||||
# Update tensors in-place to allow incremental garbage collection
|
# Update tensors in-place to allow incremental garbage collection
|
||||||
past_kv_length = max_input_length - 1
|
past_kv_length = max_input_length - 1
|
||||||
for layer in self.past_key_values:
|
for layer in self.past_key_values:
|
||||||
past_keys, past_values = layer
|
past_keys, past_values = layer
|
||||||
if len(past_keys.shape) == 3:
|
past_keys_dims = len(past_keys.shape)
|
||||||
|
if past_keys_dims == 3:
|
||||||
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
|
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
|
||||||
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
|
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
|
||||||
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
|
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
|
||||||
if self.keys_head_dim_last:
|
if is_optimized_for_gaudi:
|
||||||
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
|
layer[0] = past_keys[keep_indices]
|
||||||
|
del past_keys
|
||||||
|
layer[1] = past_values[keep_indices]
|
||||||
|
del past_values
|
||||||
else:
|
else:
|
||||||
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
|
if self.keys_head_dim_last:
|
||||||
del past_keys
|
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
|
||||||
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
|
else:
|
||||||
del past_values
|
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
|
||||||
|
del past_keys
|
||||||
|
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
|
||||||
|
del past_values
|
||||||
|
if past_keys_dims == 3:
|
||||||
|
layer[0] = layer[0].view(layer[0].shape[0] * layer[0].shape[1], *layer[0].shape[-2:])
|
||||||
|
layer[1] = layer[1].view(layer[1].shape[0] * layer[1].shape[1], *layer[1].shape[-2:])
|
||||||
|
|
||||||
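The filter logic above temporarily views BLOOM's flattened [batch * num_heads, ...] key/value cache as [batch, num_heads, ...] so whole requests can be dropped with a single index, then flattens it back. A toy reproduction of that round trip:

import torch

batch_size, num_heads, head_dim, seq_len = 4, 2, 8, 5
past_keys = torch.randn(batch_size * num_heads, head_dim, seq_len)   # BLOOM-style layout

keep_indices = torch.tensor([0, 2])                                   # requests kept after filtering
past_keys = past_keys.view(batch_size, -1, *past_keys.shape[-2:])[keep_indices]
past_keys = past_keys.view(past_keys.shape[0] * past_keys.shape[1], *past_keys.shape[-2:])
print(past_keys.shape)   # torch.Size([4, 8, 5]) -> 2 kept requests * 2 heads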
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
|
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
|
||||||
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
|
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
|
||||||
|
|
||||||
|
if kv_tuple:
|
||||||
|
self.past_key_values = [tuple(layer) for layer in self.past_key_values]
|
||||||
|
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
self.requests_idx_mapping = requests_idx_mapping
|
self.requests_idx_mapping = requests_idx_mapping
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
@ -247,7 +297,7 @@ class CausalLMBatch(Batch):
|
|||||||
self.input_lengths = input_lengths
|
self.input_lengths = input_lengths
|
||||||
self.prefix_offsets = prefix_offsets
|
self.prefix_offsets = prefix_offsets
|
||||||
self.read_offsets = read_offsets
|
self.read_offsets = read_offsets
|
||||||
self.next_token_choosers = next_token_choosers
|
self.next_token_chooser = next_token_chooser
|
||||||
self.stopping_criterias = stopping_criterias
|
self.stopping_criterias = stopping_criterias
|
||||||
self.top_n_tokens = top_n_tokens
|
self.top_n_tokens = top_n_tokens
|
||||||
self.top_n_tokens_tensor = top_n_tokens_tensor
|
self.top_n_tokens_tensor = top_n_tokens_tensor
|
||||||
@ -259,15 +309,20 @@ class CausalLMBatch(Batch):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@tracer.start_as_current_span("concatenate")
|
@tracer.start_as_current_span("concatenate")
|
||||||
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
|
def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
|
||||||
# Used for padding
|
# Used for padding
|
||||||
total_batch_size = 0
|
total_batch_size = 0
|
||||||
max_input_length = 0
|
max_input_length = 0
|
||||||
padding_right_offset = 0
|
padding_right_offset = 0
|
||||||
|
max_total_tokens = 0
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
total_batch_size += len(batch)
|
total_batch_size += len(batch)
|
||||||
max_input_length = max(max_input_length, batch.max_input_length)
|
max_input_length = max(max_input_length, batch.max_input_length)
|
||||||
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||||
|
max_total_tokens = max(max_total_tokens, batch.max_input_length + batch.padding_right_offset)
|
||||||
|
|
||||||
|
if is_optimized_for_gaudi and max_total_tokens > max_input_length:
|
||||||
|
padding_right_offset = max_total_tokens - max_input_length
|
||||||
|
|
||||||
# Batch attributes
|
# Batch attributes
|
||||||
requests = []
|
requests = []
|
||||||
@ -276,7 +331,7 @@ class CausalLMBatch(Batch):
|
|||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
next_token_choosers = []
|
next_token_chooser_parameters = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
top_n_tokens = []
|
top_n_tokens = []
|
||||||
max_tokens = 0
|
max_tokens = 0
|
||||||
@ -297,7 +352,7 @@ class CausalLMBatch(Batch):
|
|||||||
prefix_offsets.extend(batch.prefix_offsets)
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
read_offsets.extend(batch.read_offsets)
|
read_offsets.extend(batch.read_offsets)
|
||||||
all_input_ids.extend(batch.all_input_ids)
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
next_token_choosers.extend(batch.next_token_choosers)
|
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
top_n_tokens.extend(batch.top_n_tokens)
|
top_n_tokens.extend(batch.top_n_tokens)
|
||||||
|
|
||||||
@ -338,15 +393,8 @@ class CausalLMBatch(Batch):
|
|||||||
# We need to slice the attention mask to remove padding from previous steps
|
# We need to slice the attention mask to remove padding from previous steps
|
||||||
# and to remove unused allocated space
|
# and to remove unused allocated space
|
||||||
left_offset = max_input_length - batch.max_input_length
|
left_offset = max_input_length - batch.max_input_length
|
||||||
batch_left_offset = (
|
batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset
|
||||||
batch.attention_mask.shape[1]
|
attention_mask[start_index:end_index, left_offset:-padding_right_offset] = batch.attention_mask[
|
||||||
- batch.max_input_length
|
|
||||||
- batch.padding_right_offset
|
|
||||||
)
|
|
||||||
attention_mask[
|
|
||||||
start_index:end_index,
|
|
||||||
left_offset:-padding_right_offset,
|
|
||||||
] = batch.attention_mask[
|
|
||||||
:,
|
:,
|
||||||
batch_left_offset : -batch.padding_right_offset,
|
batch_left_offset : -batch.padding_right_offset,
|
||||||
]
|
]
|
||||||
@ -361,30 +409,38 @@ class CausalLMBatch(Batch):
|
|||||||
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
|
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
|
||||||
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
|
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
|
||||||
# And ensure that we can update tensors in-place
|
# And ensure that we can update tensors in-place
|
||||||
|
kv_tuple = False
|
||||||
|
past_key_values_dims = len(batch.past_key_values[0][0].shape)
|
||||||
if type(batch.past_key_values[0]) == tuple:
|
if type(batch.past_key_values[0]) == tuple:
|
||||||
batch.past_key_values = [
|
batch.past_key_values = [
|
||||||
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
|
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
|
||||||
for layer in batch.past_key_values
|
|
||||||
]
|
]
|
||||||
elif len(batch.past_key_values[0][0].shape) == 3:
|
kv_tuple = True
|
||||||
|
elif past_key_values_dims == 3:
|
||||||
for layer in batch.past_key_values:
|
for layer in batch.past_key_values:
|
||||||
for k, t in enumerate(layer):
|
for k, t in enumerate(layer):
|
||||||
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
||||||
|
|
||||||
# Add eventual padding tokens that were added while concatenating
|
# Add eventual padding tokens that were added while concatenating
|
||||||
max_tokens += batch.max_tokens + (
|
max_tokens += batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch)
|
||||||
max_input_length - batch.max_input_length
|
|
||||||
) * len(batch)
|
|
||||||
|
|
||||||
start_index = end_index
|
start_index = end_index
|
||||||
|
|
||||||
first_past_kvs = batches[0].past_key_values
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
|
next_token_chooser_parameters,
|
||||||
|
dtype=batches[0].next_token_chooser.dtype,
|
||||||
|
device=batches[0].next_token_chooser.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
first_past_kvs = batches[0].past_key_values
|
||||||
|
_, num_heads, _, head_dim = first_past_kvs[0][1].shape
|
||||||
|
padded_sequence_length = (
|
||||||
|
max_input_length + padding_right_offset if is_optimized_for_gaudi else max_input_length - 1
|
||||||
|
)
|
||||||
padded_past_values_shape = (
|
padded_past_values_shape = (
|
||||||
total_batch_size,
|
total_batch_size,
|
||||||
num_heads,
|
num_heads,
|
||||||
max_input_length - 1,
|
padded_sequence_length,
|
||||||
head_dim,
|
head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -396,7 +452,7 @@ class CausalLMBatch(Batch):
|
|||||||
total_batch_size,
|
total_batch_size,
|
||||||
num_heads,
|
num_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
max_input_length - 1,
|
padded_sequence_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Iterate over attention layers
|
# Iterate over attention layers
|
||||||
@ -413,22 +469,24 @@ class CausalLMBatch(Batch):
|
|||||||
end_index = start_index + len(batch)
|
end_index = start_index + len(batch)
|
||||||
# We slice the keys to remove the padding from previous batches
|
# We slice the keys to remove the padding from previous batches
|
||||||
past_seq_len = batch.max_input_length - 1
|
past_seq_len = batch.max_input_length - 1
|
||||||
|
# recaculate the offset
|
||||||
|
left_offset = max_input_length - batch.max_input_length
|
||||||
|
batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset
|
||||||
|
|
||||||
if batch.keys_head_dim_last:
|
if batch.keys_head_dim_last:
|
||||||
padded_past_keys[
|
padded_past_keys[
|
||||||
start_index:end_index, :, -past_seq_len:, :
|
start_index:end_index, :, left_offset : left_offset + past_seq_len, :
|
||||||
] = past_keys[:, :, -past_seq_len:, :]
|
] = past_keys[:, :, batch_left_offset : batch_left_offset + past_seq_len, :]
|
||||||
else:
|
else:
|
||||||
# BLOOM case
|
# BLOOM case
|
||||||
padded_past_keys[
|
padded_past_keys[
|
||||||
start_index:end_index, :, :, -past_seq_len:
|
start_index:end_index, :, :, left_offset : left_offset + past_seq_len
|
||||||
] = past_keys[:, :, :, -past_seq_len:]
|
] = past_keys[:, :, :, batch_left_offset : batch_left_offset + past_seq_len]
|
||||||
del past_keys
|
del past_keys
|
||||||
|
|
||||||
start_index = end_index
|
start_index = end_index
|
||||||
|
|
||||||
padded_past_values = first_past_kvs[j][1].new_zeros(
|
padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
|
||||||
padded_past_values_shape
|
|
||||||
)
|
|
||||||
start_index = 0
|
start_index = 0
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
past_values = batch.past_key_values[j][1]
|
past_values = batch.past_key_values[j][1]
|
||||||
@ -439,15 +497,30 @@ class CausalLMBatch(Batch):
|
|||||||
end_index = start_index + len(batch)
|
end_index = start_index + len(batch)
|
||||||
# We slice the past values to remove the padding from previous batches
|
# We slice the past values to remove the padding from previous batches
|
||||||
past_seq_len = batch.max_input_length - 1
|
past_seq_len = batch.max_input_length - 1
|
||||||
|
# recaculate the offset
|
||||||
|
left_offset = max_input_length - batch.max_input_length
|
||||||
|
batch_left_offset = batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset
|
||||||
|
|
||||||
padded_past_values[
|
padded_past_values[
|
||||||
start_index:end_index, :, -past_seq_len:, :
|
start_index:end_index, :, left_offset : left_offset + past_seq_len, :
|
||||||
] = past_values[:, :, -past_seq_len:, :]
|
] = past_values[:, :, batch_left_offset : batch_left_offset + past_seq_len, :]
|
||||||
del past_values
|
del past_values
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
start_index = end_index
|
start_index = end_index
|
||||||
|
|
||||||
past_key_values.append([padded_past_keys, padded_past_values])
|
if past_key_values_dims == 3:
|
||||||
|
padded_past_keys = padded_past_keys.view(
|
||||||
|
padded_past_keys.shape[0] * padded_past_keys.shape[1], *padded_past_keys.shape[-2:]
|
||||||
|
)
|
||||||
|
padded_past_values = padded_past_values.view(
|
||||||
|
padded_past_values.shape[0] * padded_past_values.shape[1], *padded_past_values.shape[-2:]
|
||||||
|
)
|
||||||
|
|
||||||
|
if kv_tuple:
|
||||||
|
past_key_values.append((padded_past_keys, padded_past_values))
|
||||||
|
else:
|
||||||
|
past_key_values.append([padded_past_keys, padded_past_values])
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=batches[0].batch_id,
|
batch_id=batches[0].batch_id,
|
||||||
@ -461,7 +534,7 @@ class CausalLMBatch(Batch):
|
|||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
@ -480,39 +553,88 @@ class CausalLM(Model):
|
|||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
device = torch.device("hpu")
|
||||||
device = torch.device("cuda")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
if quantize:
|
|
||||||
raise ValueError("quantization is not available on CPU")
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
|
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
|
||||||
|
|
||||||
|
adapt_transformers_to_gaudi()
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
padding_side="left",
|
padding_side="left",
|
||||||
truncation_side="left",
|
truncation_side="left",
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_id,
|
model_kwargs = {
|
||||||
revision=revision,
|
"revision": revision,
|
||||||
torch_dtype=dtype,
|
}
|
||||||
device_map="auto"
|
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
else None,
|
rank = int(os.getenv("RANK"), 0)
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
|
||||||
trust_remote_code=trust_remote_code,
|
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
|
||||||
)
|
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
if world_size > 1:
|
||||||
model = model.cuda()
|
import habana_frameworks.torch.hpu as torch_hpu
|
||||||
|
|
||||||
|
# Get world size, rank and local rank
|
||||||
|
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
|
||||||
|
|
||||||
|
world_size, rank, local_rank = initialize_distributed_hpu()
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
# Initialize process(es) for DeepSpeed
|
||||||
|
deepspeed.init_distributed(dist_backend="hccl")
|
||||||
|
logger.info(
|
||||||
|
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
|
||||||
|
)
|
||||||
|
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
|
||||||
|
load_to_meta = model_on_meta(config)
|
||||||
|
|
||||||
|
if load_to_meta:
|
||||||
|
# Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
|
||||||
|
with deepspeed.OnDevice(dtype=dtype, device="meta"):
|
||||||
|
model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
|
||||||
|
else:
|
||||||
|
get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
|
||||||
|
# TODO: revisit placement on CPU when auto-injection is possible
|
||||||
|
with deepspeed.OnDevice(dtype=dtype, device="cpu"):
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
ds_inference_kwargs = {"dtype": dtype}
|
||||||
|
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
|
||||||
|
ds_inference_kwargs["enable_cuda_graph"] = self.enable_hpu_graph
|
||||||
|
|
||||||
|
if load_to_meta:
|
||||||
|
# model loaded to meta is managed differently
|
||||||
|
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
|
||||||
|
write_checkpoints_json(model_id, local_rank, checkpoints_json)
|
||||||
|
ds_inference_kwargs["checkpoint"] = checkpoints_json.name
|
||||||
|
model = deepspeed.init_inference(model, **ds_inference_kwargs)
|
||||||
|
model = model.module
|
||||||
|
else:
|
||||||
|
get_repo_root(model_id)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
)
|
||||||
|
model = model.eval().to(device)
|
||||||
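The multi-card path above loads the model (optionally on the meta device), then hands it to DeepSpeed inference with the hccl backend, while the single-card path loads it directly onto the HPU. A condensed sketch of the DeepSpeed flow; it only runs on a Gaudi node with DeepSpeed for HPU installed, and the model id is a placeholder:

import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM

deepspeed.init_distributed(dist_backend="hccl")

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16)
model = deepspeed.init_inference(
    model.eval(),
    dtype=torch.bfloat16,
    tensor_parallel={"tp_size": int(os.getenv("WORLD_SIZE", "1"))},
)
model = model.module   # unwrap the DeepSpeed engine, as done above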
|
#wrap in hpu_graph only if self.enable_hpu_graph is set
|
||||||
|
if self.enable_hpu_graph:
|
||||||
|
model = wrap_in_hpu_graph(model)
|
||||||
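wrap_in_hpu_graph records the forward pass once and replays the captured HPU graph on later calls; the ENABLE_HPU_GRAPH variable above toggles it. A small sketch, assuming habana_frameworks is available on a Gaudi host and using a plain Linear layer as a stand-in for the real model:

import os
import torch
import habana_frameworks.torch as htorch   # registers the "hpu" device
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

device = torch.device("hpu")
model = torch.nn.Linear(16, 16).eval().to(device)

if os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true":
    model = wrap_in_hpu_graph(model)        # capture once, replay afterwards

with torch.no_grad():
    out = model(torch.randn(1, 16, device=device))
htorch.core.mark_step()
print(out.shape)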
|
|
||||||
|
if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
|
||||||
|
self.is_optimized_for_gaudi = True
|
||||||
|
else:
|
||||||
|
self.is_optimized_for_gaudi = False
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
if model.config.pad_token_id is not None:
|
if model.config.pad_token_id is not None:
|
||||||
@ -524,64 +646,132 @@ class CausalLM(Model):
|
|||||||
else:
|
else:
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"use_cache": True,
|
||||||
|
"return_dict": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if model.config.model_type == "llama":
|
||||||
|
kwargs["attn_softmax_bf16"] = True
|
||||||
|
kwargs["trim_logits"] = True
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0"))
|
||||||
|
self.profiling_steps = int(os.getenv("PROF_STEP", "5"))
|
||||||
|
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
|
||||||
|
self.hb_profer = HabanaProfile(
|
||||||
|
warmup=self.profiling_warmup_steps, active=self.profiling_steps, output_dir=output_dir
|
||||||
|
)
|
||||||
|
if self.profiling_warmup_steps > 0:
|
||||||
|
self.hb_profer_started = True
|
||||||
|
self.hb_profer.start()
|
||||||
|
else:
|
||||||
|
self.hb_profer = None
|
||||||
|
self.hb_profer_started = False
|
||||||
|
self.step = 0
|
||||||
|
|
||||||
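Profiling above is driven entirely by environment variables. A sketch of the same wiring; HabanaProfile ships with optimum-habana and needs a Gaudi runtime, and the commented step loop is only indicative:

import os
from optimum.habana.utils import HabanaProfile

warmup = int(os.getenv("PROF_WARMUPSTEP", "0"))
active = int(os.getenv("PROF_STEP", "5"))
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")

profiler = HabanaProfile(warmup=warmup, active=active, output_dir=output_dir)
if warmup > 0:
    profiler.start()
    # ... run warmup + active generate_token() steps here ...
    profiler.stop()      # traces end up under PROF_PATH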
@property
|
@property
|
||||||
def batch_type(self) -> Type[CausalLMBatch]:
|
def batch_type(self) -> Type[CausalLMBatch]:
|
||||||
return CausalLMBatch
|
return CausalLMBatch
|
||||||
|
|
||||||
def decode(self, generated_ids: List[int]) -> str:
|
def decode(self, generated_ids: List[int]) -> str:
|
||||||
return self.tokenizer.decode(
|
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
token_idx: Optional = None,
|
||||||
|
past_key_values: Optional = None,
|
||||||
|
bypass_hpu_graph: Optional = None,
|
||||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": True,
|
|
||||||
"return_dict": True,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.is_optimized_for_gaudi:
|
||||||
|
kwargs["token_idx"] = token_idx
|
||||||
|
|
||||||
if self.has_position_ids:
|
if self.has_position_ids:
|
||||||
kwargs["position_ids"] = position_ids
|
kwargs["position_ids"] = position_ids
|
||||||
|
|
||||||
|
if bypass_hpu_graph != None:
|
||||||
|
kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
|
||||||
|
|
||||||
|
kwargs.update(self.kwargs)
|
||||||
outputs = self.model.forward(**kwargs)
|
outputs = self.model.forward(**kwargs)
|
||||||
return outputs.logits, outputs.past_key_values
|
return outputs.logits, outputs.past_key_values
|
||||||
|
|
||||||
@tracer.start_as_current_span("generate_token")
|
@tracer.start_as_current_span("generate_token")
|
||||||
def generate_token(
|
def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
||||||
self, batch: CausalLMBatch
|
self.step = self.step + 1
|
||||||
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
|
||||||
# slice the attention mask to the correct shape
|
self.hb_profer.stop()
|
||||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
self.hb_profer_started = False
|
||||||
|
|
||||||
|
if self.is_optimized_for_gaudi:
|
||||||
|
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.padding_right_offset).to(self.device)
|
||||||
|
attention_mask = batch.attention_mask
|
||||||
|
|
||||||
|
else:
|
||||||
|
token_idx = None
|
||||||
|
# slice the attention mask to the correct shape
|
||||||
|
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||||
|
prefill = batch.past_key_values is None
|
||||||
|
if batch.past_key_values:
|
||||||
|
if token_idx is not None:
|
||||||
|
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
|
||||||
|
else:
|
||||||
|
input_ids = batch.input_ids
|
||||||
|
|
||||||
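During decode, the Gaudi path above feeds only the token at token_idx - 1 instead of the whole growing sequence, so tensor shapes stay constant across steps. A CPU-runnable illustration; buffer sizes and token ids are invented, and token_idx is kept 1-D for simplicity:

import torch

max_total_tokens = 8
input_ids = torch.zeros(2, max_total_tokens, dtype=torch.long)      # fixed-size buffer
input_ids[:, :3] = torch.tensor([[11, 12, 13], [21, 22, 23]])
token_idx = torch.tensor([3])                                       # next position to fill

# Decode step: select just the last generated token as the model input.
step_input = torch.index_select(input_ids, 1, token_idx - 1)
print(step_input.shape)                                             # torch.Size([2, 1])

# After the forward pass, write the sampled tokens in place and advance.
next_tokens = torch.tensor([[14], [24]])
input_ids[:, token_idx] = next_tokens
token_idx = token_idx + 1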
logits, past = self.forward(
|
logits, past = self.forward(
|
||||||
batch.input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
batch.position_ids,
|
batch.position_ids,
|
||||||
|
token_idx,
|
||||||
batch.past_key_values,
|
batch.past_key_values,
|
||||||
|
bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Results
|
# Results
|
||||||
generations: List[Generation] = []
|
generations: List[Generation] = []
|
||||||
stopped = True
|
stopped = True
|
||||||
|
|
||||||
|
# Select next token
|
||||||
|
input_length = batch.input_lengths[0]
|
||||||
|
if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
|
||||||
|
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
|
||||||
|
batch.input_ids[:, :token_idx], logits[:, input_length - 1 : input_length, :].squeeze(-2)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
|
||||||
|
batch.input_ids[:, :token_idx], logits.squeeze(-2)
|
||||||
|
)
|
||||||
|
|
||||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||||
batch.top_n_tokens,
|
batch.top_n_tokens,
|
||||||
batch.top_n_tokens_tensor,
|
batch.top_n_tokens_tensor,
|
||||||
torch.log_softmax(logits[:, -1], -1),
|
logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
htorch.core.mark_step()
|
||||||
|
logits = logits.to("cpu")
|
||||||
|
|
||||||
|
next_token_logprobs = next_token_logprobs.tolist()
|
||||||
|
next_token_ids = next_input_ids
|
||||||
|
|
||||||
# Zipped iterator
|
# Zipped iterator
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
batch.requests,
|
batch.requests,
|
||||||
@ -589,14 +779,16 @@ class CausalLM(Model):
|
|||||||
batch.prefix_offsets,
|
batch.prefix_offsets,
|
||||||
batch.read_offsets,
|
batch.read_offsets,
|
||||||
logits,
|
logits,
|
||||||
batch.next_token_choosers,
|
batch.next_token_chooser.do_sample,
|
||||||
|
batch.next_token_chooser.seeds,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
batch.top_n_tokens,
|
batch.top_n_tokens,
|
||||||
|
next_token_ids,
|
||||||
|
next_token_logprobs,
|
||||||
batch_top_token_ids,
|
batch_top_token_ids,
|
||||||
batch_top_token_logprobs,
|
batch_top_token_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
@ -604,32 +796,31 @@ class CausalLM(Model):
|
|||||||
prefix_offset,
|
prefix_offset,
|
||||||
read_offset,
|
read_offset,
|
||||||
logits,
|
logits,
|
||||||
next_token_chooser,
|
do_sample,
|
||||||
|
seed,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
top_n_tokens,
|
top_n_tokens,
|
||||||
|
next_token_id,
|
||||||
|
next_token_logprob,
|
||||||
top_token_ids,
|
top_token_ids,
|
||||||
top_token_logprobs,
|
top_token_logprobs,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Select next token
|
|
||||||
next_token_id, logprobs = next_token_chooser(
|
|
||||||
all_input_ids.view(1, -1), logits[-1:, :]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Append next token to all tokens
|
# Append next token to all tokens
|
||||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
if self.is_optimized_for_gaudi:
|
||||||
|
all_input_ids[input_length] = next_token_id
|
||||||
|
else:
|
||||||
|
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||||
new_input_length = input_length + 1
|
new_input_length = input_length + 1
|
||||||
|
|
||||||
# Generated token
|
# Generated token
|
||||||
next_token_logprob = logprobs[-1, next_token_id]
|
|
||||||
next_token_id_squeezed = next_token_id.squeeze()
|
|
||||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
all_input_ids[:, 0], prefix_offset, read_offset
|
all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
stop, reason = stopping_criteria(
|
stop, reason = stopping_criteria(
|
||||||
next_token_id_squeezed,
|
next_token_id,
|
||||||
next_token_text,
|
next_token_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -641,23 +832,14 @@ class CausalLM(Model):
|
|||||||
if i % self.world_size == self.rank:
|
if i % self.world_size == self.rank:
|
||||||
if stop:
|
if stop:
|
||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
output_text, _, _ = self.decode_token(
|
output_text = self.decode(
|
||||||
all_input_ids[:, 0],
|
all_input_ids[new_input_length - stopping_criteria.current_tokens : new_input_length, 0]
|
||||||
prefix_offset=len(all_input_ids)
|
|
||||||
- stopping_criteria.current_tokens
|
|
||||||
- 1,
|
|
||||||
read_offset=len(all_input_ids)
|
|
||||||
- stopping_criteria.current_tokens,
|
|
||||||
skip_special_tokens=True,
|
|
||||||
)
|
)
|
||||||
# Get seed
|
|
||||||
if isinstance(next_token_chooser.choice, Sampling):
|
|
||||||
seed = next_token_chooser.choice.seed
|
|
||||||
else:
|
|
||||||
seed = None
|
|
||||||
|
|
||||||
generated_text = GeneratedText(
|
generated_text = GeneratedText(
|
||||||
output_text, stopping_criteria.current_tokens, reason, seed
|
output_text,
|
||||||
|
stopping_criteria.current_tokens,
|
||||||
|
reason,
|
||||||
|
seed if do_sample else None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
generated_text = None
|
generated_text = None
|
||||||
@ -665,20 +847,14 @@ class CausalLM(Model):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||||
# Remove generated token to only have prefill and add nan for first prompt token
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
prefill_logprobs = [float("nan")] + next_token_logprobs
|
||||||
logits, -1
|
prefill_token_ids = all_input_ids[0 : new_input_length - 1]
|
||||||
).gather(1, all_input_ids[1:]).squeeze(1)[
|
|
||||||
-new_input_length:-1
|
|
||||||
].tolist()
|
|
||||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
|
||||||
prefill_texts = self.tokenizer.batch_decode(
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
prefill_token_ids,
|
prefill_token_ids,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
)
|
)
|
||||||
prefill_tokens = PrefillTokens(
|
prefill_tokens = PrefillTokens(prefill_token_ids, prefill_logprobs, prefill_texts)
|
||||||
prefill_token_ids, prefill_logprobs, prefill_texts
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
|
||||||
@ -688,9 +864,7 @@ class CausalLM(Model):
|
|||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
)
|
)
|
||||||
special_toptokens = [
|
special_toptokens = [token_id in self.all_special_ids for token_id in top_token_ids]
|
||||||
token_id in self.all_special_ids for token_id in top_token_ids
|
|
||||||
]
|
|
||||||
top_tokens = TopTokens(
|
top_tokens = TopTokens(
|
||||||
top_token_ids,
|
top_token_ids,
|
||||||
top_token_logprobs,
|
top_token_logprobs,
|
||||||
@ -703,40 +877,53 @@ class CausalLM(Model):
|
|||||||
generation = Generation(
|
generation = Generation(
|
||||||
request.id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
next_token_id_squeezed,
|
next_token_id,
|
||||||
next_token_logprob,
|
next_token_logprob,
|
||||||
next_token_text,
|
next_token_text,
|
||||||
next_token_id_squeezed.item() in self.all_special_ids,
|
next_token_id in self.all_special_ids,
|
||||||
generated_text,
|
generated_text,
|
||||||
top_tokens,
|
top_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
# Update values
|
|
||||||
batch.input_ids[i, 0] = next_token_id
|
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = new_input_length
|
||||||
batch.prefix_offsets[i] = prefix_offset
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
batch.read_offsets[i] = read_offset
|
batch.read_offsets[i] = read_offset
|
||||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||||
|
|
||||||
|
next_tokens = torch.tensor(next_token_ids, dtype=torch.int64).to(self.device)
|
||||||
|
if token_idx is None:
|
||||||
|
batch.input_ids[:, 0] = next_tokens[:, 0]
|
||||||
|
else:
|
||||||
|
batch.input_ids[:, token_idx] = next_tokens
|
||||||
# We finished all generations in the batch; there is no next batch
|
# We finished all generations in the batch; there is no next batch
|
||||||
if stopped:
|
if stopped:
|
||||||
|
if self.hb_profer_started == True:
|
||||||
|
self.hb_profer.step()
|
||||||
return generations, None
|
return generations, None
|
||||||
|
|
||||||
# Slice unused values from prefill
|
# Slice unused values from prefill, use it to store next token
|
||||||
batch.input_ids = batch.input_ids[:, :1]
|
if token_idx is None:
|
||||||
|
batch.input_ids = batch.input_ids[:, :1]
|
||||||
|
|
||||||
# Update attention_mask as we added a new token to input_ids
|
# Update attention_mask as we added a new token to input_ids
|
||||||
batch.attention_mask[:, -batch.padding_right_offset] = 1
|
if self.is_optimized_for_gaudi:
|
||||||
|
batch.attention_mask.index_fill_(1, token_idx, 1)
|
||||||
|
else:
|
||||||
|
batch.attention_mask[:, -batch.padding_right_offset] = 1
|
||||||
# Decrease right offset
|
# Decrease right offset
|
||||||
batch.padding_right_offset -= 1
|
batch.padding_right_offset -= 1
|
||||||
|
|
||||||
# Update position_ids
|
# Update position_ids
|
||||||
batch.position_ids = batch.position_ids[:, -1:] + 1
|
if prefill:
|
||||||
|
batch.position_ids = batch.position_ids[:, token_idx - 1 : token_idx] + 1
|
||||||
|
else:
|
||||||
|
batch.position_ids += 1
|
||||||
# Update past key values
|
# Update past key values
|
||||||
batch.past_key_values = past
|
batch.past_key_values = past
|
||||||
|
if self.hb_profer_started == True:
|
||||||
|
self.hb_profer.step()
|
||||||
|
|
||||||
return generations, batch
|
return generations, batch
|
||||||
|
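Note on the Gaudi path above: instead of growing input_ids with torch.cat on every decode step, the batch keeps a right-padded buffer and writes the freshly sampled token in place at token_idx, so tensor shapes stay static across steps. Below is a minimal CPU-only sketch of that update; the shapes and token values are made up for illustration and are not part of the commit.

import torch

max_total_tokens = 8
input_ids = torch.zeros((2, max_total_tokens), dtype=torch.int64)       # right-padded buffer
attention_mask = torch.zeros((2, max_total_tokens), dtype=torch.int64)
prompt_len = 3
input_ids[:, :prompt_len] = torch.tensor([[5, 6, 7], [8, 9, 10]])
attention_mask[:, :prompt_len] = 1
padding_right_offset = max_total_tokens - prompt_len

token_idx = torch.tensor(attention_mask.shape[-1] - padding_right_offset)  # first free slot, equals prompt_len
next_tokens = torch.tensor([42, 43])

input_ids[:, token_idx] = next_tokens                   # in-place write, shape stays (2, 8)
attention_mask.index_fill_(1, token_idx.reshape(1), 1)  # mark the new position as attended
padding_right_offset -= 1
print(input_ids)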
@@ -2,10 +2,10 @@ import inspect
 import torch

 from abc import ABC, abstractmethod
-from typing import List, Tuple, Optional, TypeVar, Type
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from typing import List, Optional, Tuple, Type, TypeVar
+from transformers import PreTrainedTokenizerBase

-from text_generation_server.models.types import Batch, Generation
+from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -21,9 +21,9 @@ class Model(ABC):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
-        sliding_window: Optional[int] = None,
+        kwargs: dict = {},
     ):
-        self.model = model.eval()
+        self.model = model
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.requires_padding = requires_padding
@@ -31,25 +31,17 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
-        self.sliding_window = sliding_window
-        self.has_position_ids = (
-            inspect.signature(model.forward).parameters.get("position_ids", None)
-            is not None
-        )
+        self.kwargs = kwargs
+        self.has_position_ids = inspect.signature(model.forward).parameters.get("position_ids", None) is not None

         self.check_initialized()

     @property
     def info(self) -> InfoResponse:
-        if self.requires_padding and self.sliding_window is not None:
-            raise NotImplementedError("sliding_window is not implemented with padding")
-
         return InfoResponse(
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
-            window_size=self.sliding_window,
         )

     @property
@@ -58,31 +50,24 @@ class Model(ABC):
         raise NotImplementedError

     @abstractmethod
-    def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
+    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError

-    def warmup(self, batch: B) -> Optional[int]:
+    def warmup(self, batch: B, max_total_tokens: int):
         self.generate_token(batch)
-        return None

     def decode_token(
         self,
         all_input_ids: List[int],
         prefix_offset: int = 0,
         read_offset: int = 0,
-        skip_special_tokens: bool = False,
     ) -> Tuple[str, int, int]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""

         # The prefix text is necessary only to defeat cleanup algorithms in the decode
         # which decide to add a space or not depending on the surrounding ids.
-        prefix_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens,
-        )
-        new_text = self.tokenizer.decode(
-            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
-        )
+        prefix_text = self.tokenizer.decode(all_input_ids[prefix_offset:read_offset], skip_special_tokens=False)
+        new_text = self.tokenizer.decode(all_input_ids[prefix_offset:], skip_special_tokens=False)

         if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
             # utf-8 char at the end means it's a potential unfinished byte sequence
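The simplified decode_token above still relies on the prefix/read offset trick to stream text incrementally. Here is a self-contained sketch of that logic; dummy_decode is a made-up stand-in for tokenizer.decode, and the unfinished-byte ("�") guard is omitted for brevity.

def dummy_decode(token_ids):
    # stand-in for tokenizer.decode: hypothetical one-token-to-one-fragment mapping
    vocab = {0: "Hello", 1: ",", 2: " world", 3: "!"}
    return "".join(vocab[t] for t in token_ids)

def decode_token(all_input_ids, prefix_offset=0, read_offset=0):
    prefix_text = dummy_decode(all_input_ids[prefix_offset:read_offset])
    new_text = dummy_decode(all_input_ids[prefix_offset:])
    if len(new_text) > len(prefix_text):
        # emit only the characters generated since the last read
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    return "", prefix_offset, read_offset

tokens = [0]                       # prompt
prefix_offset, read_offset = 0, len(tokens)
for next_token in (1, 2, 3):
    tokens.append(next_token)
    text, prefix_offset, read_offset = decode_token(tokens, prefix_offset, read_offset)
    print(repr(text))              # ',' then ' world' then '!'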
@@ -1,8 +1,5 @@
-import torch
-import torch.distributed
-
 from typing import Optional, List
-from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch

 from text_generation_server.models import CausalLM

@@ -18,28 +15,11 @@ class SantaCoder(CausalLM):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
     ):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
-
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        tokenizer.add_special_tokens(
+        super().__init__(model_id=model_id, revision=revision, dtype=dtype)
+        self.tokenizer.add_special_tokens(
             {
                 "additional_special_tokens": [
                     EOD,
@@ -51,25 +31,7 @@ class SantaCoder(CausalLM):
                 "pad_token": EOD,
             }
         )
-        with device:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-                load_in_8bit=quantize == "bitsandbytes",
-                trust_remote_code=trust_remote_code,
-            )
-
-        super(CausalLM, self).__init__(
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-        )
-
     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-        )
+        return self.tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import sys
 import torch

 from grpc import aio
@@ -14,7 +15,6 @@ from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -23,16 +23,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.model = model
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
+        # op not optimized issue. Will investigate further.
+        # if model.device.type == "hpu":
+        #     Force inference mode for the lifetime of TextGenerationService
+        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def Info(self, request, context):
         return self.model.info

     async def Health(self, request, context):
-        if self.model.device.type == "cuda":
-            torch.zeros((2, 2)).cuda()
+        if self.model.device.type == "hpu":
+            torch.zeros((2, 2)).to("hpu")
         return generate_pb2.HealthResponse()

     async def ServiceDiscovery(self, request, context):
@@ -49,47 +51,27 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.request_ids)
+        filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
         self.cache.set(filtered_batch)

         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        if (
-            self.model.batch_type == IdeficsCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
-                request.batch,
-                self.model.tokenizer,
-                self.model.processor,
-                self.model.dtype,
-                self.model.device,
-            )
-        else:
-            batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-            )
-        max_supported_total_tokens = self.model.warmup(batch)
-
-        return generate_pb2.WarmupResponse(
-            max_supported_total_tokens=max_supported_total_tokens
-        )
+        # batch = self.model.batch_type.from_pb(
+        #     request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+        # )
+        # max_supported_total_tokens = self.model.warmup(batch)
+
+        # return generate_pb2.WarmupResponse(
+        #     max_supported_total_tokens=max_supported_total_tokens
+        # )
+        logger.warning("Warmup is not enabled on HPU.")
+        return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):
-        if (
-            self.model.batch_type == IdeficsCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb(
-                request.batch,
-                self.model.tokenizer,
-                self.model.processor,
-                self.model.dtype,
-                self.model.device,
-            )
-        else:
-            batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-            )
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+        )

         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
@@ -114,7 +96,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
         else:
             batch = batches[0]

@@ -130,54 +112,53 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_id: str,
     revision: Optional[str],
-    sharded: bool,
-    quantize: Optional[str],
     dtype: Optional[str],
-    trust_remote_code: bool,
     uds_path: Path,
+    sharded: bool,
 ):
+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation_server",
+        level="INFO",
+        serialize=False,
+        backtrace=True,
+        diagnose=False,
+    )
+
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
-        sharded: bool = False,
-        quantize: Optional[str] = None,
         dtype: Optional[str] = None,
-        trust_remote_code: bool = False,
+        sharded: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
+        logger.info("Server:server_inner: sharded ={}".format(sharded))

         if sharded:
+            rank = int(os.environ["RANK"])
+            logger.info("Server:server_inner: rank ={}".format(rank))
             server_urls = [
-                unix_socket_template.format(uds_path, rank)
-                for rank in range(int(os.environ["WORLD_SIZE"]))
+                unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"]))
             ]
             local_url = server_urls[int(os.environ["RANK"])]
         else:
             local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]

+        logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url))
+        if dtype == "bfloat16" or None:
+            data_type = torch.bfloat16
+        else:
+            data_type = torch.float
         try:
-            model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
-            )
+            model = get_model(model_id, revision=revision, dtype=data_type)
         except Exception:
             logger.exception("Error when initializing model")
             raise

-        if quantize == "gptq":
-            try:
-                # When using GPTQ, Exllama kernels need some global kernels
-                # For which we have the finale shapes only after the model has loaded
-                # This will allocate those buffers.
-                from text_generation_server.utils.gptq.exllama import (
-                    create_exllama_buffers,
-                    set_device,
-                )
-
-                set_device(model.device)
-                create_exllama_buffers()
-            except ImportError:
-                pass
-
         server = aio.server(
             interceptors=[
                 ExceptionInterceptor(),
@@ -204,6 +185,9 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)

-    asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
-    )
+    logger.info(
+        "Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(
+            model_id, revision, dtype, sharded
+        )
+    )
+    asyncio.run(serve_inner(model_id, revision, dtype, sharded))
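One subtlety in serve_inner above: the condition `dtype == "bfloat16" or None` only selects bfloat16 when dtype is exactly "bfloat16", because `None` is falsy in Python; an unset dtype therefore falls back to torch.float. The tiny snippet below illustrates the truthiness involved; it is commentary, not part of the commit.

for dtype in ("bfloat16", "float32", None):
    as_written = bool(dtype == "bfloat16" or None)   # behaves like a plain equality test
    explicit = dtype in ("bfloat16", None)           # what a "default to bf16 when unset" intent would look like
    print(dtype, as_written, explicit)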
server/text_generation_server/tgi_service.py (new file, +29 lines)
@@ -0,0 +1,29 @@
+import os
+from pathlib import Path
+from loguru import logger
+import sys
+from text_generation_server import server
+import argparse
+
+
+def main(args):
+    logger.info("TGIService: starting tgi service .... ")
+    logger.info(
+        "TGIService: --model_id {}, --revision {}, --sharded {}, --dtype {}, --uds_path {} ".format(
+            args.model_id, args.revision, args.sharded, args.dtype, args.uds_path
+        )
+    )
+    server.serve(
+        model_id=args.model_id, revision=args.revision, dtype=args.dtype, uds_path=args.uds_path, sharded=args.sharded
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", type=str)
+    parser.add_argument("--revision", type=str)
+    parser.add_argument("--sharded", type=bool)
+    parser.add_argument("--dtype", type=str)
+    parser.add_argument("--uds_path", type=Path)
+    args = parser.parse_args()
+    main(args)
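For reference, the new entry point boils down to the server.serve call sketched below; the model id and socket path are placeholders, not values from the commit, and running it requires the Habana TGI environment.

from pathlib import Path
from text_generation_server import server

server.serve(
    model_id="bigscience/bloom-560m",   # placeholder model id
    revision=None,
    dtype="bfloat16",
    uds_path=Path("/tmp/text-generation-server"),  # placeholder socket path
    sharded=False,
)

One caveat with the argparse setup above: `--sharded` is declared with `type=bool`, so any non-empty string (including "False") parses as True; only an empty string yields False.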
@@ -44,6 +44,12 @@ class FakeGroup:


 def initialize_torch_distributed():
+    import habana_frameworks.torch.core as htcore
+
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+
+    options = None
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL

@@ -56,9 +62,13 @@ def initialize_torch_distributed():
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
+    elif torch.hpu.is_available():
+        backend = "hccl"
+        n_hpus = torch.hpu.device_count()
+        if world_size > n_hpus:
+            raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).")
     else:
         backend = "gloo"
-        options = None

     if WORLD_SIZE == 1:
         return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
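A hypothetical single-process use of the helper above; the import path assumes the usual utils.dist module of this server, and with the new code the call requires habana_frameworks to be importable on the machine.

from text_generation_server.utils.dist import initialize_torch_distributed

# With WORLD_SIZE unset (i.e. 1), a FakeGroup stands in for a real HCCL/NCCL/GLOO process group.
process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size)  # 0 1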
@@ -1,5 +1,6 @@
 import math
 import torch
+import habana_frameworks.torch.core as htcore

 from functools import lru_cache
 from typing import Optional, List, Dict, Union
@@ -36,37 +37,31 @@ class StaticWarper:
         if typical_p is not None and typical_p < 1.0:
             self.warpers.append(TypicalLogitsWarper(mass=typical_p))

-        self.cuda_graph = None
+        self.hpu_graph = None
         self.static_scores = None
         self.static_warped_scores = None
         self.static_next_logprob = None

     def __call__(self, scores):
-        if torch.cuda.is_available():
-            if self.cuda_graph is None:
-                self.static_scores = scores
-                self.cuda_graph = torch.cuda.CUDAGraph()
-
-                with torch.cuda.graph(self.cuda_graph, pool=mempool):
-                    local_scores = self.static_scores
-                    for warper in self.warpers:
-                        local_scores = warper(None, local_scores)
-
-                    self.static_warped_scores = local_scores
-                    # Compute logprobs
-                    self.static_next_logprob = torch.log_softmax(
-                        self.static_warped_scores, -1
-                    )
-
-            self.static_scores.copy_(scores)
-            self.cuda_graph.replay()
-
-            return self.static_warped_scores, self.static_next_logprob
-
-        # CPU branch
-        for warper in self.warpers:
-            scores = warper(None, scores)
-        return scores, torch.log_softmax(scores, -1)
+        if self.hpu_graph is None:
+            self.static_scores = scores.clone().contiguous()
+            self.static_warped_scores = scores.clone().contiguous()
+            self.static_next_logprob = scores.clone().contiguous()
+            self.hpu_graph = htcore.hpu.HPUGraph()
+
+            with htcore.hpu.graph(self.hpu_graph):
+                local_scores = self.static_scores
+                for warper in self.warpers:
+                    local_scores = warper(None, local_scores)
+
+                self.static_warped_scores.copy_(local_scores)
+                # Compute logprobs
+                self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1))
+
+        self.static_scores.copy_(scores)
+        self.hpu_graph.replay()
+
+        return self.static_warped_scores, self.static_next_logprob
@@ -76,9 +71,7 @@ def static_warper(
     top_p: Optional[float],
     typical_p: Optional[float],
 ) -> StaticWarper:
-    return StaticWarper(
-        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
-    )
+    return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)


 class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
@@ -95,17 +88,13 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):

     def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
         self.penalty = penalty
-        self.penalty_tensor = torch.tensor(
-            penalty, dtype=dtype, device=device
-        ).unsqueeze(1)
+        self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         score = torch.gather(scores, 1, input_ids)

         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-        score = torch.where(
-            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
-        )
+        score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor)

         scores.scatter_(1, input_ids, score)
         return scores
@@ -129,13 +118,9 @@ class HeterogeneousTemperatureLogitsWarper:
         The value used to module the logits distribution.
     """

-    def __init__(
-        self, temperature: List[float], dtype: torch.dtype, device: torch.device
-    ):
+    def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device):
         self.temperature = temperature
-        self.temperature_tensor = torch.tensor(
-            temperature, dtype=dtype, device=device
-        ).unsqueeze(1)
+        self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         scores.div_(self.temperature_tensor)
@@ -174,9 +159,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         min_tokens_to_keep: int = 1,
     ):
         self.top_p = top_p
-        self.top_p_opposite = 1 - torch.tensor(
-            top_p, dtype=dtype, device=device
-        ).unsqueeze(1)
+        self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep

@@ -193,9 +176,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

         # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

         return warped_scores
@@ -243,9 +224,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         disabled = [x == 0 for x in top_k]

         if any(disabled):
-            self.top_k_disabled_mask = torch.tensor(
-                disabled, dtype=torch.bool, device=device
-            ).view(-1, 1)
+            self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1)
         else:
             self.top_k_disabled_mask = None

@@ -281,9 +260,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
             self.max_top_k = max(self.top_k)

             if self.top_k_disabled_mask is not None:
-                self.top_k_disabled_mask = (
-                    self.top_k_disabled_mask[indices] if any(disabled) else None
-                )
+                self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None

             return self
         return None
@@ -349,15 +326,11 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
         if self.disabled_mask is not None:
             last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)

-        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
-            1, last_ind.view(-1, 1)
-        )
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
         if self.min_tokens_to_keep > 1:
             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
             sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

         warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

@@ -371,9 +344,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
             self.mass_tensor = self.mass_tensor[indices]

             if self.disabled_mask is not None:
-                self.disabled_mask = (
-                    self.disabled_mask[indices] if any(disabled) else None
-                )
+                self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None

             return self
         return None
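The StaticWarper rewrite above follows a capture-once, replay-afterwards pattern: the warper chain is recorded into an HPU graph that reads from and writes to preallocated static tensors, and later calls only copy_ new logits in and replay. Below is a short sketch of that pattern using the same calls as the diff; it assumes a Gaudi device with the habana_frameworks PyTorch bridge installed, and a single log_softmax stands in for the real warper chain.

import torch
import habana_frameworks.torch.core as htcore  # only available on Gaudi machines

static_in = torch.zeros((4, 32000), dtype=torch.bfloat16, device="hpu")
static_out = torch.zeros_like(static_in)

hpu_graph = htcore.hpu.HPUGraph()
with htcore.hpu.graph(hpu_graph):
    # stand-in for the chain of logits warpers captured by StaticWarper
    static_out.copy_(torch.log_softmax(static_in / 0.7, dim=-1))

def warp(scores: torch.Tensor) -> torch.Tensor:
    static_in.copy_(scores)   # feed fresh logits into the graph's input buffer
    hpu_graph.replay()        # re-run the recorded kernels with static shapes
    return static_out         # results land in the captured output buffer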
@@ -30,13 +30,9 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
     ):
-        self.watermark_processor = (
-            WatermarkLogitsProcessor(device=device) if watermark else None
-        )
+        self.watermark_processor = WatermarkLogitsProcessor(device=device) if watermark else None
         self.repetition_processor = (
-            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
-            if repetition_penalty
-            else None
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) if repetition_penalty else None
         )

         has_warpers = (
@@ -46,9 +42,7 @@ class NextTokenChooser:
             or (typical_p is not None and typical_p < 1.0)
         )
         if has_warpers:
-            self.static_warper = static_warper(
-                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
-            )
+            self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
         else:
             self.static_warper = None

@@ -136,9 +130,7 @@ class StoppingCriteria:
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
-        stop_sequence_criterias = [
-            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
-        ]
+        stop_sequence_criterias = [StopSequenceCriteria(sequence) for sequence in pb.stop_sequences]
         return StoppingCriteria(
             tokenizer.eos_token_id,
             stop_sequence_criterias,
@@ -176,20 +168,14 @@ class HeterogeneousNextTokenChooser:
         )

         self.repetition_processor = (
-            HeterogeneousRepetitionPenaltyLogitsProcessor(
-                repetition_penalty, dtype, device
-            )
+            HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, dtype, device)
             if any([x != 1.0 for x in repetition_penalty])
             else None
         )

         if any([x != 1.0 for x in temperature]):
-            do_sample = [
-                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
-            ]
-            warpers.append(
-                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
-            )
+            do_sample = [sample or x != 1.0 for x, sample in zip(temperature, do_sample)]
+            warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, dtype, device))

         if any([x != 0 for x in top_k]):
             do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
@@ -277,7 +263,7 @@ class HeterogeneousNextTokenChooser:

 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator(device)
+        self.generator = torch.Generator("cpu")
         self.generator.manual_seed(seed)
         self.seed = seed

@@ -355,30 +341,21 @@ def batch_top_tokens(
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
     sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
-    nth_highest = torch.gather(
-        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
-    )
+    nth_highest = torch.gather(sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1))
     nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

     # Find the new "fuzzy" top n values
     top_n_indices = (logprobs >= nth_highest).nonzero()
     _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)

-    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
     # Take a new topk for these new max n values
-    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)

     top_n_ishes = top_n_ishes.tolist()
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()

     return (
-        [
-            idxs[:n] if req_n > 0 else []
-            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
-        ],
-        [
-            vals[:n] if req_n > 0 else []
-            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
-        ],
+        [idxs[:n] if req_n > 0 else [] for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)],
+        [vals[:n] if req_n > 0 else [] for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)],
     )
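The batch_top_tokens path above keeps every token whose logprob ties the n-th best value, which is why it is described as a "fuzzy" top-n: rows can return more ids than requested. A small CPU-only sketch of that selection, with made-up logits:

import torch

logprobs = torch.log_softmax(torch.tensor([[2.0, 2.0, 1.0, 0.0],
                                           [3.0, 1.0, 1.0, 1.0]]), dim=-1)
top_n_tokens = torch.tensor([1, 2])

sorted_top = torch.topk(logprobs, k=int(top_n_tokens.max()), dim=1, sorted=True).values
nth_highest = torch.gather(sorted_top, 1, (top_n_tokens - 1).clip(min=0).unsqueeze(1))

keep = logprobs >= nth_highest   # ties with the n-th value are kept as well
print([row.nonzero().flatten().tolist() for row in keep])
# first request asked for 1 token but gets the tied ids [0, 1]; second gets [0, 1, 2, 3]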
@@ -34,21 +34,17 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         # watermarking parameters
         self.gamma = gamma
         self.delta = delta
-        self.rng = torch.Generator(device=device)
+        self.rng = torch.Generator(device="cpu")
         self.hash_key = hash_key

     def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            assert (
-                len(input_ids) >= 1
-            ), "requires at least a 1 token prefix sequence to seed rng"
+            assert len(input_ids) >= 1, "requires at least a 1 token prefix sequence to seed rng"
             prev_token = input_ids[-1]
         else:
             assert len(input_ids) == 1
             input_ids = input_ids[0]
-            assert (
-                input_ids.shape[-1] >= 1
-            ), "requires at least a 1 token prefix sequence to seed rng"
+            assert input_ids.shape[-1] >= 1, "requires at least a 1 token prefix sequence to seed rng"
             prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)

@@ -67,9 +63,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         return greenlist_ids

     @staticmethod
-    def _calc_greenlist_mask(
-        scores: torch.FloatTensor, greenlist_token_ids
-    ) -> torch.BoolTensor:
+    def _calc_greenlist_mask(scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
         green_tokens_mask = torch.zeros_like(scores)
         green_tokens_mask[-1, greenlist_token_ids] = 1
         final_mask = green_tokens_mask.bool()
@@ -82,15 +76,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
         return scores

-    def __call__(
-        self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
-        greenlist_ids = self._get_greenlist_ids(
-            input_ids, scores.shape[-1], scores.device
-        )
-        green_tokens_mask = self._calc_greenlist_mask(
-            scores=scores, greenlist_token_ids=greenlist_ids
-        )
+    def __call__(self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor) -> torch.FloatTensor:
+        greenlist_ids = self._get_greenlist_ids(input_ids, scores.shape[-1], scores.device)
+        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=greenlist_ids)

         scores = self._bias_greenlist_logits(
             scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
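For context, the watermark processor shown above reseeds its generator from the previous token, draws a gamma-sized greenlist from the vocabulary, and adds delta to those logits. A compact, self-contained sketch of the same steps; the vocabulary size and values are illustrative only.

import torch

gamma, delta, hash_key, vocab_size = 0.5, 2.0, 15485863, 10
scores = torch.zeros(1, vocab_size)
prev_token = 7

rng = torch.Generator(device="cpu")
rng.manual_seed(hash_key * prev_token)            # seed depends on the previous token
greenlist_size = int(vocab_size * gamma)
greenlist_ids = torch.randperm(vocab_size, generator=rng)[:greenlist_size]

mask = torch.zeros_like(scores, dtype=torch.bool)
mask[-1, greenlist_ids] = True
scores[mask] += delta                             # bias the greenlisted logits
print(greenlist_ids.tolist(), scores)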