This draft PR is a work-in-progress implementation of the Mamba model. It currently loads weights and produces correct logits after a single forward pass. It still needs to integrate the model so that it produces tokens as expected, and to add optimizations that avoid unnecessary copies and other redundant operations at runtime.

#### Helpful resources

- [Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
- https://github.com/johnma2006/mamba-minimal
- https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
- https://github.com/huggingface/transformers/pull/28094

Notes: this dev work currently targets `state-spaces/mamba-130m`, so if you want to test, please use that model. Additionally, when starting the router, the prefill needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768`

## Update / Current State

Integration tests have been added, and basic functionality such as model loading is supported:

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```

- [x] add tests
- [x] load model
- [x] make simple request
- [ ] resolve warmup issue
- [ ] resolve output issues

Fetching the models tested during dev:

```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can be run with:

```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

Start the router:

```bash
cargo run
```

Make a request:

```bash
curl -s localhost:3000/generate \
  -X POST \
  -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
  -H 'Content-Type: application/json' | jq
```

Response:

```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
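
For reference, once the model is fully integrated, the separate server and router steps above should collapse into a single launcher invocation. A hedged sketch, reusing the prefill limits from the note above (flag names assume the standard `text-generation-launcher` CLI):

```bash
# Hypothetical one-step launch; the launcher spawns the Python server shards
# and the router itself. Limits mirror the `cargo run` flags used during dev.
text-generation-launcher \
    --model-id state-spaces/mamba-130m \
    --max-batch-prefill-tokens 768 \
    --max-input-length 768 \
    --port 3000
```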
## Dockerfile

The repository `Dockerfile`, including the new `mamba-builder` kernel stage (`server/Makefile-selective-scan`):

```dockerfile
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
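
# cargo-chef is used in two passes: the `planner` stage computes a recipe of
# the workspace's dependencies, and the `builder` stage "cooks" (pre-builds)
# them, so Docker caches the dependency layer independently of source edits.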
FROM chef AS planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL
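
# Install protoc, the protobuf compiler used at build time to generate the
# gRPC code from the `proto` definitions.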
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release

# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu20.04 AS pytorch-install

ARG PYTORCH_VERSION=2.1.1
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.1
ARG MAMBA_VERSION=23.3.1-1
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM

ENV PATH=/opt/conda/bin:$PATH

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git && \
    rm -rf /var/lib/apt/lists/*

# Install conda
# translating Docker's TARGETPLATFORM into Mambaforge (conda) arches
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
    *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") exit 1 ;; \
    *) /opt/conda/bin/conda update -y conda && \
    /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

# CUDA kernels builder image
FROM pytorch-install AS kernel-builder

ARG MAX_JOBS=8

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    ninja-build \
    && rm -rf /var/lib/apt/lists/*
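
# Every kernel builder stage below derives from this image. The
# TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" setting used in several of them
# restricts compilation to Ampere-class GPUs (PTX is kept for forward
# compatibility with newer architectures).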

# Build Flash Attention CUDA kernels
FROM kernel-builder AS flash-att-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att Makefile

# Build specific version of flash attention
RUN make build-flash-attention

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-cuda

# Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder

WORKDIR /usr/src

COPY server/exllama_kernels/ .

RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Transformers exllamav2 kernels
FROM kernel-builder AS exllamav2-kernels-builder

WORKDIR /usr/src

COPY server/exllamav2_kernels/ .

# Build specific version of exllamav2 kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Transformers awq kernels
FROM kernel-builder AS awq-kernels-builder

WORKDIR /usr/src

COPY server/Makefile-awq Makefile

# Build specific version of awq
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq

# Build eetq kernels
FROM kernel-builder AS eetq-kernels-builder

WORKDIR /usr/src

COPY server/Makefile-eetq Makefile

# Build specific version of eetq
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq

# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder

WORKDIR /usr/src

COPY server/custom_kernels/ .

# Build the custom kernels
RUN python setup.py build

# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder

WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-cuda

# Build mamba kernels
FROM kernel-builder AS mamba-builder

WORKDIR /usr/src

COPY server/Makefile-selective-scan Makefile

RUN make build-all
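# `build-all` compiles both the selective-scan (mamba) and causal-conv1d
# extensions; their build artifacts are copied into the base image below.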

# Build megablocks
FROM kernel-builder AS megablocks-builder

RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 AS base

# Conda env
ENV PATH=/opt/conda/bin:$PATH \
    CONDA_PREFIX=/opt/conda

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
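
# Weights downloaded at runtime are cached under HUGGINGFACE_HUB_CACHE;
# mount a volume at /data to persist them across container runs.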

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    libssl-dev \
    ca-certificates \
    make \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_cuda.txt && \
    pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
```
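
To try the image locally: the last stage is the default build target, and the Sagemaker variant can be selected with `--target sagemaker`. A sketch, assuming an illustrative tag `tgi-mamba` and a single-GPU host:

```bash
# Build the final image (use `--target sagemaker` for the AWS-compatible variant).
docker build -t tgi-mamba .

# Run with a volume mounted at /data so downloaded weights persist, passing
# the same prefill limits used during development. The container listens on
# port 80 (ENV PORT=80 above), mapped here to the localhost:3000 used earlier.
docker run --gpus all -p 3000:80 -v "$PWD/data:/data" tgi-mamba \
    --model-id state-spaces/mamba-130m \
    --max-batch-prefill-tokens 768 \
    --max-input-length 768
```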