Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

commit 92fa7ac7e9: Merge branch 'main' into gpt_awq_4
@@ -4,3 +4,4 @@ server/transformers
 server/flash-attention
 cmake-build-debug/
 cmake-build-release/
+Dockerfile*
2  .github/workflows/build.yaml (vendored)
@@ -45,7 +45,7 @@ jobs:
 export dockerfile="Dockerfile"
 export label_extension=""
 export docker_devices=""
-export runs_on="aws-g6-12xlarge-plus-priv"
+export runs_on="aws-g6-12xl-plus-priv-cache"
 export platform=""
 ;;
 rocm)
4  .github/workflows/tests.yaml (vendored)
@@ -42,6 +42,7 @@ jobs:
 sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
 - name: Install
   run: |
+    sudo apt update
     sudo apt install python3.11-dev -y
     make install-cpu
 - name: Run server tests
@@ -57,3 +58,6 @@ jobs:
 - name: Run Rust tests
   run: |
     cargo test
+- name: Run Rust tests with google feature
+  run: |
+    cargo test --features google
3  .gitignore (vendored)
@@ -3,9 +3,8 @@ target
 router/tokenizer.json
 *__pycache__*
 
+backends/v2/src/client/pb
 backends/v3/src/client/pb
-backends/client/src/v2/pb
-backends/client/src/v3/pb
 
 # ROCm auto-generated files
 *.hip
@@ -23,9 +23,11 @@ docs/openapi.json:
 - '#/components/schemas/GenerateResponse/properties/details/nullable'
 - '#/components/schemas/StreamResponse/properties/details/nullable'
 - '#/components/schemas/ChatRequest/properties/response_format/nullable'
+- '#/components/schemas/ChatRequest/properties/stream_options/nullable'
 - '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
 - '#/components/schemas/ToolChoice/nullable'
 - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
+- '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
 - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
 no-invalid-media-type-examples:
 - '#/paths/~1/post/responses/422/content/application~1json/example'
677  Cargo.lock (generated)
File diff suppressed because it is too large.
13  Cargo.toml
@@ -1,24 +1,26 @@
 [workspace]
 members = [
     "benchmark",
+    "backends/v2",
     "backends/v3",
     "backends/grpc-metadata",
     "backends/trtllm",
-    "backends/client",
-    "launcher"
+    "launcher",
+    "router"
 ]
 default-members = [
     "benchmark",
+    "backends/v2",
     "backends/v3",
     "backends/grpc-metadata",
     # "backends/trtllm",
-    "backends/client",
-    "launcher"
+    "launcher",
+    "router"
 ]
 resolver = "2"
 
 [workspace.package]
-version = "2.2.1-dev0"
+version = "2.3.2-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
@@ -31,6 +33,7 @@ metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 minijinja = { version = "2.2.0", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
 
 [profile.release]
 incremental = true
@@ -40,7 +40,6 @@ COPY router router
 COPY backends backends
 COPY launcher launcher
 RUN cargo build --profile release-opt
-RUN cargo build --profile release-opt
 
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -258,7 +257,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
+    pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3
 
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
24  Dockerfile.nix (new file)

# Build the image and get out the docker file:
#
# docker build -t tgi-nix-builder -f Dockerfile.nix
# docker run --log-driver=none tgi-nix-builder | docker load

FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use text-generation-inference
WORKDIR /root
ADD . .
RUN nix build .
RUN mkdir /tmp/nix-store-closure
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure

FROM ubuntu:24.04

WORKDIR /app

# Copy /nix/store
COPY --from=builder /tmp/nix-store-closure /nix/store
COPY --from=builder /root/result /app
RUN ldconfig
CMD ["ldconfig", "/app/bin/text-generation-launcher"]
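For reference, a usage sketch based on the comments at the top of this new file: a trailing build context (`.`) is added here, and since the image name printed by `docker load` is not shown in the diff, the final run line is an assumption.

# Build the builder image and stream the resulting TGI image into the local Docker daemon
docker build -t tgi-nix-builder -f Dockerfile.nix .
docker run --log-driver=none tgi-nix-builder | docker load
# Run the loaded image; LOADED_IMAGE and the flags below are illustrative
docker run --gpus all --shm-size 1g -p 8080:80 "$LOADED_IMAGE" --model-id HuggingFaceH4/zephyr-7b-beta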
175  Dockerfile_amd
@@ -41,7 +41,7 @@ COPY launcher launcher
 RUN cargo build --profile release-opt
 
 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
+FROM rocm/dev-ubuntu-22.04:6.2 AS base
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -50,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     git \
     make \
+    libmsgpack-dev \
     libssl-dev \
+    llvm-dev \
     g++ \
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
     hipblas-dev \
-    hipblaslt-dev \
+    hipcub-dev \
     rocblas-dev \
     hiprand-dev \
+    hipfft-dev \
     rocrand-dev \
     miopen-hip-dev \
-    hipfft-dev \
-    hipcub-dev \
     hipsolver-dev \
     rccl-dev \
     cmake \
-    python3.11-dev && \
+    python3.11-venv && \
     rm -rf /var/lib/apt/lists/*
 
 # Keep in sync with `server/pyproject.toml
 ARG MAMBA_VERSION=23.1.0-1
-ARG PYTORCH_VERSION='2.3.0'
-ARG ROCM_VERSION='6.0.2'
 ARG PYTHON_VERSION='3.11.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
 
+ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
 # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
 # Install mamba
@@ -100,41 +101,132 @@ RUN case ${TARGETPLATFORM} in \
     /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
     esac && \
     /opt/conda/bin/conda clean -ya
 
 # Install flash-attention, torch dependencies
-RUN pip install numpy einops ninja --no-cache-dir
+RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
 
-RUN pip uninstall -y triton && \
-    git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
-    cd triton/python && \
-    pip install .
-
-RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+RUN conda install mkl=2021
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
 
-ARG _GLIBCXX_USE_CXX11_ABI="1"
-ARG CMAKE_PREFIX_PATH="/opt/conda"
+ARG COMMON_WORKDIR=/
+WORKDIR ${COMMON_WORKDIR}
 
+# Install HIPBLASLt
+FROM base AS build_hipblaslt
+ARG HIPBLASLT_BRANCH="e6da924"
+RUN git clone https://github.com/ROCm/hipBLASLt.git \
+    && cd hipBLASLt \
+    && git checkout ${HIPBLASLT_BRANCH} \
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
+    && cd build/release \
+    && make package
+
+FROM scratch AS export_hipblaslt
+ARG COMMON_WORKDIR
+COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
+
+# RCCL build stages
+FROM base AS build_rccl
+ARG RCCL_BRANCH="rocm-6.2.0"
+RUN git clone https://github.com/ROCm/rccl \
+    && cd rccl \
+    && git checkout ${RCCL_BRANCH} \
+    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
+FROM scratch AS export_rccl
+ARG COMMON_WORKDIR
+COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
+
+# Triton build stages
+FROM base AS build_triton
+ARG TRITON_BRANCH="e192dba"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
+    && cd triton \
+    && git checkout ${TRITON_BRANCH} \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_triton
+ARG COMMON_WORKDIR
+COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
+
+# # AMD-SMI build stages
+FROM base AS build_amdsmi
+RUN cd /opt/rocm/share/amd_smi \
+    && pip wheel . --wheel-dir=dist
+FROM scratch AS export_amdsmi
+COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
+
+
+FROM base as build_pytorch
+
+RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+    if ls /install/*.deb; then \
+    dpkg -i /install/*.deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+    fi
+
+ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
 ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
-ARG BUILD_CAFFE2="0" \
-    BUILD_CAFFE2_OPS="0" \
-    USE_CUDA="0" \
-    USE_ROCM="1" \
-    BUILD_TEST="0" \
-    USE_FBGEMM="0" \
-    USE_NNPACK="0" \
-    USE_QNNPACK="0" \
-    USE_XNNPACK="0" \
-    USE_FLASH_ATTENTION="1" \
-    USE_MEM_EFF_ATTENTION="0"
 
-RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
+# A commit to fix the output scaling factor issue in _scaled_mm
+# Not yet in 2.5.0-rc1
+ARG PYTORCH_BRANCH="cedc116"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
 
-# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
-ENV HIP_FORCE_DEV_KERNARG=1
+RUN git clone ${PYTORCH_REPO} pytorch \
+    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
+    && pip install -r requirements.txt --no-cache-dir \
+    && python tools/amd_build/build_amd.py \
+    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch as export_pytorch
+ARG COMMON_WORKDIR
+COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
 
-# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
-# However, Triton requires a tunning for each prompt length, which is prohibitive.
-ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
+FROM base AS install_deps
 
-FROM base AS kernel-builder
+ARG COMMON_WORKDIR
 
+# Install hipblaslt
+RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+    if ls /install/*.deb; then \
+    dpkg -i /install/*.deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+    fi
+
+RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
+    if ls /install/*.deb; then \
+    dpkg -i /install/*.deb \
+    # RCCL needs to be installed twice
+    && dpkg -i /install/*.deb \
+    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
+    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
+    fi
+
+RUN --mount=type=bind,from=export_triton,src=/,target=/install \
+    if ls /install/*.whl; then \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y triton \
+    && pip install /install/*.whl; \
+    fi
+
+RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y amdsmi \
+    && pip install /install/*.whl;
+
+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+    if ls /install/*.whl; then \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y torch torchvision \
+    && pip install /install/*.whl; \
+    fi
+
+FROM install_deps AS kernel-builder
+
 # # Build vllm kernels
 FROM kernel-builder AS vllm-builder
@@ -174,7 +266,7 @@ COPY server/exllamav2_kernels/ .
 
 RUN python setup.py build
 
-FROM base AS base-copy
+FROM install_deps AS base-copy
 
 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -224,6 +316,19 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy
 
+# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+ENV HIP_FORCE_DEV_KERNARG=1
+
+# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
+# However, Triton requires a tunning for each prompt length, which is prohibitive.
+ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV ROCM_USE_SKINNY_GEMM=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
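As a usage note, the ROCm image defined by Dockerfile_amd builds like any other Docker image; the tag below is an assumption, and the device flags follow the usual ROCm container pattern of exposing the kfd/dri devices:

docker build -f Dockerfile_amd -t tgi-rocm:dev .
docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $PWD/data:/data \
    tgi-rocm:dev --model-id HuggingFaceH4/zephyr-7b-beta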
@@ -83,7 +83,7 @@ model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data
 
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model
+    ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
 ```
 
 And then you can make requests like
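The request referenced on the last context line is the standard TGI /generate call against the port mapped above; a minimal sketch, with an illustrative prompt and parameters:

curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'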
17  _server.nix (deleted)
@@ -1,17 +0,0 @@
-{
-  mkPoetryApplication,
-  pkg-config,
-  protobuf,
-  openssl,
-}:
-
-mkPoetryApplication {
-  # name = "text-generation-server";
-
-  projectDir = ./server;
-
-  # nativeBuildInputs = [ pkg-config ];
-
-  # buildInputs = [ openssl.dev protobuf ];
-
-}
75  backends/v2/Cargo.toml (new file)

[package]
name = "text-generation-router-v2"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

[lib]
path = "src/lib.rs"

[[bin]]
name = "text-generation-router-v2"
path = "src/main.rs"

[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
    "rt",
    "rt-multi-thread",
    "parking_lot",
    "signal",
    "sync",
] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = [
    "opentelemetry-otlp",
] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"

[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"

[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]
19  backends/v2/build.rs (new file)

use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("cargo:rerun-if-changed=../../proto/");

    fs::create_dir_all("src/client/pb").unwrap_or(());
    let mut config = prost_build::Config::new();
    config.protoc_arg("--experimental_allow_proto3_optional");

    tonic_build::configure()
        .build_client(true)
        .build_server(false)
        .out_dir("src/client/pb")
        .include_file("mod.rs")
        .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));

    Ok(())
}
506  backends/v2/src/backend.rs (new file)

use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};

pub struct BackendV2 {
    /// Request queue
    queue: Queue,
    /// Notify batcher on queue appends
    batching_task_notifier: Arc<Notify>,
    /// Client clone, used for health checks to skip the queue
    client: ShardedClient,
}

impl BackendV2 {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        client: ShardedClient,
        waiting_served_ratio: f32,
        max_batch_prefill_tokens: u32,
        max_batch_total_tokens: u32,
        max_waiting_tokens: usize,
        max_batch_size: Option<usize>,
        requires_padding: bool,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        // Infer shared state
        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
            attention
                .parse()
                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
        } else {
            Attention::Paged
        };
        let block_size = if attention == Attention::FlashDecoding {
            256
        } else {
            16
        };
        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
        let batching_task_notifier = Arc::new(Notify::new());

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            client.clone(),
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            queue.clone(),
            batching_task_notifier.clone(),
        ));

        Self {
            queue,
            batching_task_notifier,
            client,
        }
    }
}

#[async_trait]
impl Backend for BackendV2 {
    #[instrument(skip_all)]
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = mpsc::unbounded_channel();

        // Append the request to the queue
        self.queue.append(Entry {
            request,
            response_tx,
            span: Span::current(),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
        });

        // Notify the background task that we have a new entry in the queue that needs
        // to be batched
        self.batching_task_notifier.notify_one();

        // Return stream
        Ok(UnboundedReceiverStream::new(response_rx))
    }

    async fn health(&self, current_health: bool) -> bool {
        if current_health {
            // Generation is healthy, we only check that the shards can allocate on device
            self.client.device_health().await
        } else {
            self.client.model_health().await
        }
        .is_ok()
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
    mut client: ShardedClient,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    queue: Queue,
    notifier: Arc<Notify>,
) {
    // Infinite loop
    loop {
        // Wait for a notification from the Infer struct
        notifier.notified().await;

        // Get the next batch from the queue
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the queue
        while let Some((mut entries, batch, span)) = queue
            .next_batch(
                None,
                max_batch_size,
                max_batch_prefill_tokens,
                max_batch_total_tokens,
            )
            .await
        {
            let mut cached_batch = prefill(&mut client, batch, &mut entries)
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let batch_max_tokens = batch.max_tokens;
                let mut batches = vec![batch];
                metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
                metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

                let min_size = if waiting_tokens >= max_waiting_tokens {
                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                    // to add a new batch even though its size might be small
                    None
                } else {
                    // Minimum batch size
                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
                };

                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
                let max_size =
                    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
                // Try to get a new batch
                if let Some((mut new_entries, new_batch, span)) = queue
                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                    .await
                {
                    // Tracking metrics
                    if min_size.is_some() {
                        metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
                            .increment(1);
                    } else {
                        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
                            .increment(1);
                    }

                    entries.iter_mut().for_each(|(_, entry)| {
                        // Create a new span to add the info that this entry is waiting
                        // because a new batch is being computed
                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
                        // Add relationships
                        span.follows_from(&entry_waiting_span);
                        entry_waiting_span.follows_from(&span);
                        // Update entry
                        entry.temp_span = Some(entry_waiting_span);
                    });

                    // Generate one token for this new batch to have the attention past in cache
                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
                        .instrument(span)
                        .await;
                    // Reset waiting counter
                    waiting_tokens = 1;
                    // Extend current batch with the new batch
                    if let Some(new_cached_batch) = new_cached_batch {
                        entries.extend(new_entries);
                        batches.push(new_cached_batch);
                    }
                }

                // Create span for this batch to add context to inference calls
                let next_batch_size = entries.len();
                let next_batch_span =
                    info_span!(parent: None, "batch", batch_size = next_batch_size);
                entries.iter_mut().for_each(|(_, entry)| {
                    // Create a new span to link the batch back to this entry
                    let entry_batch_span = info_span!(parent: &entry.span, "infer");
                    // Add relationships
                    next_batch_span.follows_from(&entry_batch_span);
                    entry_batch_span.follows_from(&next_batch_span);
                    // Update entry
                    entry.temp_span = Some(entry_batch_span);
                });

                cached_batch = decode(&mut client, batches, &mut entries)
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
            }
            metrics::gauge!("tgi_batch_current_size").set(0.0);
            metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
        }
    }
}

#[instrument(skip_all)]
async fn prefill(
    client: &mut ShardedClient,
    batch: Batch,
    entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let start_time = Instant::now();
    let batch_id = batch.id;
    metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

    match client.prefill(batch).await {
        Ok((generations, next_batch, timings)) => {
            let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);

            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;

            metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
                .record(timings.forward.as_secs_f64());
            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
                .record(timings.decode.as_secs_f64());
            metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
                .record(start_filtering_time.elapsed().as_secs_f64());
            metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
                .record(start_time.elapsed().as_secs_f64());
            metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            let _ = client.clear_cache(Some(batch_id)).await;
            send_errors(err, entries);
            metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
            None
        }
    }
}

#[instrument(skip_all)]
async fn decode(
    client: &mut ShardedClient,
    batches: Vec<CachedBatch>,
    entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let start_time = Instant::now();
    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
    metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

    match client.decode(batches).await {
        Ok((generations, next_batch, timings)) => {
            let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);

            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;

            if let Some(concat_duration) = timings.concat {
                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
                    .record(concat_duration.as_secs_f64());
            }
            metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
                .record(timings.forward.as_secs_f64());
            metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
                .record(timings.decode.as_secs_f64());
            metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
                .record(start_filtering_time.elapsed().as_secs_f64());
            metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
                .record(start_time.elapsed().as_secs_f64());
            metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            for id in batch_ids {
                let _ = client.clear_cache(Some(id)).await;
            }
            send_errors(err, entries);
            metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
            None
        }
    }
}

/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
    client: &mut ShardedClient,
    next_batch: Option<CachedBatch>,
    entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let mut batch = next_batch?;

    // No need to filter
    if batch.size as usize == entries.len() {
        return Some(batch);
    }

    let id = batch.id;

    // Retain only requests that are still in entries
    batch.request_ids.retain(|id| entries.contains_key(id));

    if batch.request_ids.is_empty() {
        // All requests have been filtered out
        // Next batch is now empty
        // Clear it from the Python shards cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.clear_cache(Some(id)).await.unwrap();
        None
    } else {
        // Filter Python shard cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.filter_batch(id, batch.request_ids).await.unwrap()
    }
}

/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    generations.into_iter().for_each(|generation| {
        let id = generation.request_id;
        // Get entry
        // We can `expect` here as the request id should always be in the entries
        let entry = entries
            .get(&id)
            .expect("ID not found in entries. This is a bug.");

        // Create and enter a span to link this function back to the entry
        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
        // Send generation responses back to the infer task
        // If the receive an error from the Flume channel, it means that the client dropped the
        // request and we need to stop generating hence why we unwrap_or(true)
        let stopped = send_responses(generation, entry).inspect_err(|_err| {
            tracing::error!("Entry response channel error.");
            metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
        }).unwrap_or(true);
        if stopped {
            entries.remove(&id).expect("ID not found in entries. This is a bug.");
        }
    });
}

/// Send responses through the `entry` response channel
fn send_responses(
    generation: Generation,
    entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
    // Return directly if the channel is disconnected
    if entry.response_tx.is_closed() {
        metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
        return Ok(true);
    }

    let mut stopped = false;

    if let Some(prefill_tokens) = generation.prefill_tokens {
        // Create Token objects
        // We do that here instead of in the Python code as Rust for loops are faster
        let prefill_tokens = prefill_tokens
            .ids
            .into_iter()
            .zip(prefill_tokens.logprobs)
            .zip(prefill_tokens.texts)
            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
            .collect();

        // Send message
        entry
            .response_tx
            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
    }

    // Create last Token
    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
    let n = tokens_.ids.len();
    metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
    let mut iterator = tokens_
        .ids
        .into_iter()
        .zip(tokens_.logprobs)
        .zip(tokens_.texts)
        .zip(tokens_.is_special)
        .enumerate()
        .peekable();
    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
        let token = Token {
            id,
            text,
            logprob,
            special,
        };
        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
            top_tokens_
                .ids
                .iter()
                .zip(top_tokens_.logprobs.iter())
                .zip(top_tokens_.texts.iter())
                .zip(top_tokens_.is_special.iter())
                .map(|(((&id, &logprob), text), &special)| Token {
                    id,
                    text: text.to_string(),
                    logprob,
                    special,
                })
                .collect()
        } else {
            vec![]
        };
        match (&generation.generated_text, iterator.peek()) {
            (Some(generated_text), None) => {
                // Generation has ended
                stopped = true;
                // Send message
                entry.response_tx.send(Ok(InferStreamResponse::End {
                    token,
                    top_tokens,
                    generated_text: GeneratedText::from(generated_text.clone()),
                    queued: entry.queue_time,
                    start: entry.batch_time.unwrap(),
                }))?;
            }
            _ => {
                // Send message
                entry
                    .response_tx
                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
            }
        }
    }

    Ok(stopped)
}

/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // Create and enter a span to link this function back to the entry
        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
        let err = InferError::GenerationError(error.to_string());
        metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
        tracing::error!("{err}");

        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry
            .response_tx
            .send(Err(err))
            .unwrap_or(());
    });
}

impl From<crate::client::GeneratedText> for GeneratedText {
    fn from(value: crate::client::GeneratedText) -> Self {
        let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
        let finish_reason = match v2_finish_reason {
            crate::client::FinishReason::Length => FinishReason::Length,
            crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
        };

        Self {
            text: value.text,
            generated_tokens: value.generated_tokens,
            finish_reason,
            seed: value.seed,
        }
    }
}
257  backends/v2/src/client/grpc_client.rs (new file)

/// Single shard Client
use crate::client::pb;
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v2::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
    stub: TextGenerationServiceClient<Channel>,
}

impl Client {
    /// Returns a client connected to the given url
    #[allow(dead_code)]
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a list of uris or unix sockets of all shards
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await.map_err(|_| {
            ClientError::Connection("Server does not support v2 interface".to_string())
        })?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            // Remove unix socket prefix
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

    /// Get model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let request = tonic::Request::new(InfoRequest {}).inject_context();
        let response = self.stub.info(request).await?.into_inner();
        Ok(response)
    }

    /// Get model health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    /// Clear the past generations cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
        self.stub.clear_cache(request).await?;
        Ok(())
    }

    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let request = tonic::Request::new(FilterBatchRequest {
            batch_id,
            request_ids,
        })
        .inject_context();
        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
        Ok(filtered_batch.batch)
    }

    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip_all)]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_size: Option<usize>,
    ) -> Result<Option<u32>> {
        let mut n_tokens = 0;
        let mut requests = Vec::new();
        // Create requests
        while n_tokens < max_prefill_tokens {
            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

            let mut inputs = String::new();
            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
            if n_tokens == 0 {
                // 1 request is enough to test vision heads.
                // Sending images on other queries messes up easily with truncation.
                inputs.push_str(&format!(
                    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
                ));
            }

            requests.push(Request {
                id: 0,
                inputs,
                // We truncate the input on the server side to be sure that it has the correct size
                truncate,
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.2,
                    frequency_penalty: 0.1,
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: max_total_tokens - truncate,
                    stop_sequences: vec![],
                    ignore_eos_token: true,
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
            });
            n_tokens += max_input_length;

            // Check max_batch_size
            if Some(requests.len()) == max_batch_size {
                break;
            }
        }

        let batch = Batch {
            id: 0,
            size: requests.len() as u32,
            requests,
            max_tokens: 0,
        };

        let request = tonic::Request::new(WarmupRequest {
            batch: Some(batch),
            max_input_length,
            max_prefill_tokens,
            max_total_tokens,
        })
        .inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
        Ok(response.max_supported_total_tokens)
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
        ))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            DecodeTimings::new(
                response.concat_ns,
                response.forward_ns,
                response.decode_ns,
                response.total_ns,
            ),
        ))
    }
}

pub struct PrefillTimings {
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl PrefillTimings {
    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

pub struct DecodeTimings {
    pub concat: Option<Duration>,
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl DecodeTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}
68  backends/v2/src/client/mod.rs (new file)

//! Text Generation gRPC client library

use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;

#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;

mod grpc_client;
mod sharded_client;

pub use grpc_client::Client;
pub use pb::generate::v2::{
    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
    InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;

#[async_trait]
pub trait Health {
    /// Check if a generate server is healthy by asking it to allocate a tensor on device
    async fn device_health(&self) -> Result<()>;

    /// Check if a generate server is healthy by doing a forward pass.
    /// EXPENSIVE
    async fn model_health(&self) -> Result<()>;
}

#[derive(Debug)]
pub struct ShardInfo {
    pub requires_padding: bool,
    pub dtype: String,
    pub device_type: String,
    pub window_size: Option<u32>,
    pub speculate: u32,
}

#[derive(Error, Debug, Clone)]
pub enum ClientError {
    #[error("Could not connect to Text Generation server: {0}")]
    Connection(String),
    #[error("Server error: {0}")]
    Generation(String),
    #[error("Sharded results are empty")]
    EmptyResults,
}

impl From<Status> for ClientError {
    fn from(err: Status) -> Self {
        let err = Self::Generation(err.message().to_string());
        tracing::error!("{err}");
        err
    }
}

impl From<transport::Error> for ClientError {
    fn from(err: transport::Error) -> Self {
        let err = Self::Connection(err.to_string());
        tracing::error!("{err}");
        err
    }
}

static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

pub type Result<T> = std::result::Result<T, ClientError>;
252
backends/v2/src/client/sharded_client.rs
Normal file
252
backends/v2/src/client/sharded_client.rs
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
/// Multi shard Client
|
||||||
|
use crate::client::{ClientError, Result};
|
||||||
|
use crate::client::{Health, ShardInfo};
|
||||||
|
|
||||||
|
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||||
|
use crate::client::InfoResponse;
|
||||||
|
use crate::client::{
|
||||||
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::future::join_all;
|
||||||
|
use tonic::transport::Uri;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Text Generation Inference gRPC multi client
|
||||||
|
pub struct ShardedClient {
|
||||||
|
clients: Vec<Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShardedClient {
|
||||||
|
fn new(clients: Vec<Client>) -> Self {
|
||||||
|
Self { clients }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
|
/// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
|
||||||
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||||
|
// Get all uris/unix sockets from the master client
|
||||||
|
let uris = master_client.service_discovery().await?;
|
||||||
|
let futures = uris.into_iter().map(Client::connect_uds);
|
||||||
|
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||||
|
Ok(Self::new(clients?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given uri
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let master_client = Client::connect(uri).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let master_client = Client::connect_uds(path).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.info())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GRPC health check
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.health())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.clear_cache(batch_id))
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.into_iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||||
|
.collect();
|
||||||
|
// all shards return the same message
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum number of tokens supported by the hardware
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| {
|
||||||
|
Box::pin(client.warmup(
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
// Take the minimum value
|
||||||
|
let results = join_all(futures)
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||||
|
Ok(results.into_iter().flatten().min())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InfoResponse> for ShardInfo {
|
||||||
|
fn from(value: InfoResponse) -> Self {
|
||||||
|
Self {
|
||||||
|
requires_padding: value.requires_padding,
|
||||||
|
dtype: value.dtype,
|
||||||
|
device_type: value.device_type,
|
||||||
|
window_size: value.window_size,
|
||||||
|
speculate: value.speculate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Health for ShardedClient {
|
||||||
|
async fn device_health(&self) -> Result<()> {
|
||||||
|
self.clone().health().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_health(&self) -> Result<()> {
|
||||||
|
// Dummy batch of 1 token and 1 generated token
|
||||||
|
let liveness_request = Request {
|
||||||
|
id: u64::MAX,
|
||||||
|
inputs: "liveness".to_string(),
|
||||||
|
truncate: 10,
|
||||||
|
prefill_logprobs: false,
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_k: 0,
|
||||||
|
top_p: 1.0,
|
||||||
|
typical_p: 1.0,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
watermark: false,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: 1,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: false,
|
||||||
|
}),
|
||||||
|
top_n_tokens: 0,
|
||||||
|
};
|
||||||
|
let batch = Batch {
|
||||||
|
id: u64::MAX,
|
||||||
|
requests: vec![liveness_request],
|
||||||
|
size: 1,
|
||||||
|
max_tokens: 2,
|
||||||
|
};
|
||||||
|
self.clone().prefill(batch).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
141
backends/v2/src/lib.rs
Normal file
141
backends/v2/src/lib.rs
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
mod backend;
|
||||||
|
mod client;
|
||||||
|
mod queue;
|
||||||
|
|
||||||
|
use crate::client::{ClientError, ShardedClient};
|
||||||
|
pub(crate) use backend::BackendV2;
|
||||||
|
use serde::Serialize;
|
||||||
|
use thiserror::Error;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
|
pub struct BackendInfo {
|
||||||
|
/// Mandatory
|
||||||
|
#[schema(example = "cuda")]
|
||||||
|
pub model_device_type: String,
|
||||||
|
#[schema(example = "torch.float16")]
|
||||||
|
pub model_dtype: String,
|
||||||
|
|
||||||
|
/// Backend parameters
|
||||||
|
#[schema(example = "1")]
|
||||||
|
pub speculate: usize,
|
||||||
|
#[schema(example = "1.2")]
|
||||||
|
pub waiting_served_ratio: f32,
|
||||||
|
#[schema(example = "32000")]
|
||||||
|
pub max_batch_total_tokens: u32,
|
||||||
|
#[schema(example = "20")]
|
||||||
|
pub max_waiting_tokens: usize,
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub max_batch_size: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn connect_backend(
|
||||||
|
max_input_tokens: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<(BackendV2, BackendInfo), V2Error> {
|
||||||
|
// Helper function
|
||||||
|
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||||
|
match max_supported_batch_total_tokens {
|
||||||
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
|
None => {
|
||||||
|
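// Fall back to 16k batch tokens, but never below the configured per-request total or the prefill budget.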
let max_batch_total_tokens = max_batch_total_tokens
|
||||||
|
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||||
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
|
Ok(max_batch_total_tokens)
|
||||||
|
}
|
||||||
|
// Flash attention models return their max supported total tokens
|
||||||
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
// Warn if the user set their own max-batch-total-tokens as we will ignore it
|
||||||
|
if max_batch_total_tokens.is_some() {
|
||||||
|
tracing::warn!(
|
||||||
|
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||||
|
Attention models."
|
||||||
|
);
|
||||||
|
tracing::warn!(
|
||||||
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
|
return Err(V2Error::NotEnoughMemory(max_total_tokens));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(max_supported_batch_total_tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
|
.await
|
||||||
|
.map_err(V2Error::Connection)?;
|
||||||
|
|
||||||
|
// server is running on v2
|
||||||
|
// Clear the cache; useful if the webserver rebooted
|
||||||
|
sharded_client
|
||||||
|
.clear_cache(None)
|
||||||
|
.await
|
||||||
|
.map_err(V2Error::Cache)?;
|
||||||
|
// Get info from the shard
|
||||||
|
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
|
||||||
|
|
||||||
|
// Warmup model
|
||||||
|
tracing::info!("Warming up model");
|
||||||
|
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||||
|
sharded_client
|
||||||
|
.warmup(
|
||||||
|
max_input_tokens as u32,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_total_tokens as u32,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(V2Error::Warmup)?,
|
||||||
|
)?;
|
||||||
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
|
|
||||||
|
let backend_info = BackendInfo {
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
model_device_type: shard_info.device_type.clone(),
|
||||||
|
model_dtype: shard_info.dtype.clone(),
|
||||||
|
speculate: shard_info.speculate as usize,
|
||||||
|
};
|
||||||
|
|
||||||
|
let backend = BackendV2::new(
|
||||||
|
sharded_client,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
shard_info.requires_padding,
|
||||||
|
shard_info.window_size,
|
||||||
|
shard_info.speculate,
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::info!("Using backend V2");
|
||||||
|
|
||||||
|
Ok((backend, backend_info))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum V2Error {
|
||||||
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
|
Cache(ClientError),
|
||||||
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
|
Connection(ClientError),
|
||||||
|
#[error("Unable to get the Python model shards info: {0}")]
|
||||||
|
Info(ClientError),
|
||||||
|
#[error("Unable to warmup the Python model shards: {0}")]
|
||||||
|
Warmup(ClientError),
|
||||||
|
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||||
|
NotEnoughMemory(usize),
|
||||||
|
}
|
212
backends/v2/src/main.rs
Normal file
212
backends/v2/src/main.rs
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use text_generation_router::{server, usage_stats};
|
||||||
|
use text_generation_router_v2::{connect_backend, V2Error};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// App Configuration
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[clap(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Option<Commands>,
|
||||||
|
|
||||||
|
#[clap(default_value = "128", long, env)]
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
max_best_of: usize,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
#[clap(default_value = "5", long, env)]
|
||||||
|
max_top_n_tokens: u32,
|
||||||
|
#[clap(default_value = "1024", long, env)]
|
||||||
|
max_input_tokens: usize,
|
||||||
|
#[clap(default_value = "2048", long, env)]
|
||||||
|
max_total_tokens: usize,
|
||||||
|
#[clap(default_value = "1.2", long, env)]
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
#[clap(default_value = "4096", long, env)]
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
#[clap(default_value = "20", long, env)]
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
|
hostname: String,
|
||||||
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
|
port: u16,
|
||||||
|
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
|
tokenizer_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
revision: Option<String>,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
validation_workers: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
api_key: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
json_output: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
otlp_endpoint: Option<String>,
|
||||||
|
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||||
|
otlp_service_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_edge: Option<String>,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
messages_api_enabled: bool,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
disable_grammar_support: bool,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_client_batch_size: usize,
|
||||||
|
#[clap(default_value = "on", long, env)]
|
||||||
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
PrintSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), RouterError> {
|
||||||
|
// Get args
|
||||||
|
let args = Args::parse();
|
||||||
|
// Pattern match configuration
|
||||||
|
let Args {
|
||||||
|
command,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
master_shard_uds_path,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
validation_workers,
|
||||||
|
api_key,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
otlp_service_name,
|
||||||
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
|
max_client_batch_size,
|
||||||
|
usage_stats,
|
||||||
|
} = args;
|
||||||
|
|
||||||
|
if let Some(Commands::PrintSchema) = command {
|
||||||
|
use utoipa::OpenApi;
|
||||||
|
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||||
|
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||||
|
println!("{}", api_doc);
|
||||||
|
std::process::exit(0);
|
||||||
|
};
|
||||||
|
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
|
||||||
|
// Validate args
|
||||||
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if validation_workers == 0 {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||||
|
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
|
if max_batch_size == 0 {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_batch_size` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (backend, _backend_info) = connect_backend(
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
master_shard_uds_path,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Run server
|
||||||
|
server::run(
|
||||||
|
backend,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
validation_workers,
|
||||||
|
api_key,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
|
max_client_batch_size,
|
||||||
|
usage_stats,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
enum RouterError {
|
||||||
|
#[error("Argument validation error: {0}")]
|
||||||
|
ArgumentValidation(String),
|
||||||
|
#[error("Backend failed: {0}")]
|
||||||
|
Backend(#[from] V2Error),
|
||||||
|
#[error("WebServer error: {0}")]
|
||||||
|
WebServer(#[from] server::WebServerError),
|
||||||
|
#[error("Tokio runtime failed to start: {0}")]
|
||||||
|
Tokio(#[from] std::io::Error),
|
||||||
|
}
|
@ -1,14 +1,14 @@
|
|||||||
use crate::infer::{InferError, InferStreamResponse};
|
use crate::client::{
|
||||||
use crate::validation::{
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
|
||||||
};
|
};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use text_generation_client::v2::{
|
use text_generation_router::infer::InferError;
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
use text_generation_router::infer::InferStreamResponse;
|
||||||
|
use text_generation_router::validation::{
|
||||||
|
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||||
};
|
};
|
||||||
use text_generation_client::ChunksToString;
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{info_span, instrument, Span};
|
use tracing::{info_span, instrument, Span};
|
||||||
@ -218,7 +218,7 @@ impl State {
|
|||||||
|
|
||||||
// Create span for this batch to add context to inference calls
|
// Create span for this batch to add context to inference calls
|
||||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||||
next_batch_span.follows_from(&Span::current());
|
next_batch_span.follows_from(Span::current());
|
||||||
|
|
||||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||||
let mut batch_entries =
|
let mut batch_entries =
|
||||||
@ -404,6 +404,7 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::sync::Arc;
|
||||||
use tracing::info_span;
|
use tracing::info_span;
|
||||||
|
|
||||||
fn default_entry() -> (
|
fn default_entry() -> (
|
||||||
@ -415,7 +416,9 @@ mod tests {
|
|||||||
let entry = Entry {
|
let entry = Entry {
|
||||||
request: ValidGenerateRequest {
|
request: ValidGenerateRequest {
|
||||||
inputs: vec![],
|
inputs: vec![],
|
||||||
|
input_ids: Some(Arc::new(vec![])),
|
||||||
input_length: 0,
|
input_length: 0,
|
||||||
|
add_special_tokens: true,
|
||||||
truncate: 0,
|
truncate: 0,
|
||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
parameters: ValidParameters {
|
parameters: ValidParameters {
|
@ -100,6 +100,7 @@ pub async fn connect_backend(
|
|||||||
.map_err(V3Error::Warmup)?,
|
.map_err(V3Error::Warmup)?,
|
||||||
)?;
|
)?;
|
||||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
|
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
|
||||||
|
|
||||||
let backend_info = BackendInfo {
|
let backend_info = BackendInfo {
|
||||||
waiting_served_ratio,
|
waiting_served_ratio,
|
||||||
|
@ -364,7 +364,7 @@ impl State {
|
|||||||
// Add it back to the front
|
// Add it back to the front
|
||||||
tracing::debug!("Over budget: not enough free blocks");
|
tracing::debug!("Over budget: not enough free blocks");
|
||||||
self.entries.push_front((id, entry));
|
self.entries.push_front((id, entry));
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
Some(block_allocation) => {
|
Some(block_allocation) => {
|
||||||
tracing::debug!("Allocation: {block_allocation:?}");
|
tracing::debug!("Allocation: {block_allocation:?}");
|
||||||
@ -436,6 +436,12 @@ impl State {
|
|||||||
batch_entries.insert(id, entry);
|
batch_entries.insert(id, entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Empty batch
|
||||||
|
if batch_requests.is_empty() {
|
||||||
|
tracing::debug!("Filtered out all entries");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
// Final batch size
|
// Final batch size
|
||||||
let size = batch_requests.len() as u32;
|
let size = batch_requests.len() as u32;
|
||||||
next_batch_span.record("batch_size", size);
|
next_batch_span.record("batch_size", size);
|
||||||
|
@ -1,10 +1,22 @@
|
|||||||
use crate::block_allocator::{Allocator, BlockAllocation};
|
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||||
use slotmap::{DefaultKey, SlotMap};
|
use slotmap::{DefaultKey, SlotMap};
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeSet, HashMap},
|
collections::{BTreeSet, HashMap},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
fn hash(slice: &[u32]) -> u64 {
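// Children in the radix trie are keyed by this hash, computed over one block worth of token ids.
// A single-token key is returned as-is so short keys skip the hashing overhead.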
|
||||||
|
assert!(!slice.is_empty());
|
||||||
|
if slice.len() == 1 {
|
||||||
|
slice[0] as u64
|
||||||
|
} else {
|
||||||
|
let mut s = std::hash::DefaultHasher::new();
|
||||||
|
slice.hash(&mut s);
|
||||||
|
s.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct RadixAllocator {
|
pub struct RadixAllocator {
|
||||||
allocation_id: u64,
|
allocation_id: u64,
|
||||||
|
|
||||||
@ -44,6 +56,10 @@ impl RadixAllocator {
|
|||||||
// the free list if we cannot allocate enough blocks. This is only
|
// the free list if we cannot allocate enough blocks. This is only
|
||||||
// temporary, the trie needs to be able to report whether it can
|
// temporary, the trie needs to be able to report whether it can
|
||||||
// allocate the requested amount. Just not implemented yet.
|
// allocate the requested amount. Just not implemented yet.
|
||||||
|
tracing::debug!(
|
||||||
|
"Free blocks {} need {n_blocks_needed}",
|
||||||
|
self.free_blocks.len()
|
||||||
|
);
|
||||||
self.free_blocks.extend(
|
self.free_blocks.extend(
|
||||||
self.cache_blocks
|
self.cache_blocks
|
||||||
.evict(n_blocks_needed - self.free_blocks.len()),
|
.evict(n_blocks_needed - self.free_blocks.len()),
|
||||||
@ -94,6 +110,9 @@ impl Allocator for RadixAllocator {
|
|||||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||||
None => {
|
None => {
|
||||||
|
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
|
||||||
|
tracing::debug!("Found {prefix_len} prefix tokens, need {suffix_blocks} suffix blocks for {tokens} tokens");
|
||||||
|
tracing::debug!("Block size {}", self.block_size);
|
||||||
self.cache_blocks
|
self.cache_blocks
|
||||||
.decref(prefix_node)
|
.decref(prefix_node)
|
||||||
.expect("Failed to decrement refcount");
|
.expect("Failed to decrement refcount");
|
||||||
@ -211,7 +230,6 @@ struct RadixAllocation {
|
|||||||
pub enum TrieError {
|
pub enum TrieError {
|
||||||
InvalidNodeId,
|
InvalidNodeId,
|
||||||
RefCountUnderflow,
|
RefCountUnderflow,
|
||||||
BlockTokenCountMismatch,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type NodeId = DefaultKey;
|
pub type NodeId = DefaultKey;
|
||||||
@ -268,7 +286,9 @@ impl RadixTrie {
|
|||||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||||
let node = &self.nodes[node_id];
|
let node = &self.nodes[node_id];
|
||||||
|
|
||||||
if let Some(&child_id) = node.children.get(&key[0]) {
|
if key.len() >= self.block_size {
|
||||||
|
let node_key = hash(&key[..self.block_size]);
|
||||||
|
if let Some(&child_id) = node.children.get(&node_key) {
|
||||||
self.update_access_time(child_id);
|
self.update_access_time(child_id);
|
||||||
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
||||||
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
|
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
|
||||||
@ -280,6 +300,7 @@ impl RadixTrie {
|
|||||||
node_id = self.find_(child_id, key, blocks);
|
node_id = self.find_(child_id, key, blocks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
node_id
|
node_id
|
||||||
}
|
}
|
||||||
@ -344,9 +365,11 @@ impl RadixTrie {
|
|||||||
// evict n_blocks and return `None` if we can't. We are now needlessly
|
// evict n_blocks and return `None` if we can't. We are now needlessly
|
||||||
// evicting prefixes from the cache in such a case.
|
// evicting prefixes from the cache in such a case.
|
||||||
let mut evicted = Vec::new();
|
let mut evicted = Vec::new();
|
||||||
|
tracing::debug!("Evicting in search of {n_blocks}");
|
||||||
|
|
||||||
while let Some((last_access, node_id)) = self.leaves.pop_first() {
|
while let Some((last_access, node_id)) = self.leaves.pop_first() {
|
||||||
let blocks_needed = n_blocks - evicted.len();
|
let blocks_needed = n_blocks.saturating_sub(evicted.len());
|
||||||
|
tracing::debug!("Evicting node {node_id:?} ");
|
||||||
|
|
||||||
let node = self.nodes.get(node_id).expect("Leaf does not exist");
|
let node = self.nodes.get(node_id).expect("Leaf does not exist");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -368,8 +391,11 @@ impl RadixTrie {
|
|||||||
// the required number of blocks and leave the remaining blocks
|
// the required number of blocks and leave the remaining blocks
|
||||||
// untouched.
|
// untouched.
|
||||||
let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
|
let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
|
||||||
node.key.truncate(node.blocks.len() - blocks_needed);
|
|
||||||
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
|
let truncate_blocks = node.blocks.len() - blocks_needed;
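// `node.key` stores one entry per token while `node.blocks` stores one entry per block, so scale by `block_size` before truncating the key.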
|
||||||
|
let truncate_tokens = truncate_blocks * self.block_size;
|
||||||
|
node.key.truncate(truncate_tokens);
|
||||||
|
evicted.extend(node.blocks.split_off(truncate_blocks));
|
||||||
self.leaves.insert((last_access, node_id));
|
self.leaves.insert((last_access, node_id));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -400,11 +426,10 @@ impl RadixTrie {
|
|||||||
// the part of the prefix that is already in the trie to detect
|
// the part of the prefix that is already in the trie to detect
|
||||||
// mismatches.
|
// mismatches.
|
||||||
|
|
||||||
if tokens.len() != blocks.len() * self.block_size {
|
assert_eq!(tokens.len(), blocks.len() * self.block_size);
|
||||||
return Err(TrieError::BlockTokenCountMismatch);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
|
let node_key = hash(&tokens[..self.block_size]);
|
||||||
|
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
|
||||||
self.update_access_time(child_id);
|
self.update_access_time(child_id);
|
||||||
let child = self
|
let child = self
|
||||||
.nodes
|
.nodes
|
||||||
@ -452,14 +477,15 @@ impl RadixTrie {
|
|||||||
.get_mut(node_id)
|
.get_mut(node_id)
|
||||||
.expect("Node to-be split does not exist");
|
.expect("Node to-be split does not exist");
|
||||||
let mut parent_key = node.key.split_off(prefix_len);
|
let mut parent_key = node.key.split_off(prefix_len);
|
||||||
let mut parent_blocks = node.blocks.split_off(prefix_len);
|
let prefix_blocks = prefix_len / self.block_size;
|
||||||
|
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
|
||||||
|
|
||||||
// Move first part of the prefix to the parent. We swap to avoid
|
// Move first part of the prefix to the parent. We swap to avoid
|
||||||
// an allocation + copy for both splits of the key/blocks.
|
// an allocation + copy for both splits of the key/blocks.
|
||||||
std::mem::swap(&mut node.key, &mut parent_key);
|
std::mem::swap(&mut node.key, &mut parent_key);
|
||||||
std::mem::swap(&mut node.blocks, &mut parent_blocks);
|
std::mem::swap(&mut node.blocks, &mut parent_blocks);
|
||||||
|
|
||||||
let node_key = node.key[0];
|
let node_key = hash(&node.key[..self.block_size]);
|
||||||
|
|
||||||
let grandparent_id = node.parent.expect("Node does not have a parent");
|
let grandparent_id = node.parent.expect("Node does not have a parent");
|
||||||
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
|
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
|
||||||
@ -484,7 +510,7 @@ impl RadixTrie {
|
|||||||
) -> NodeId {
|
) -> NodeId {
|
||||||
let key = key.into();
|
let key = key.into();
|
||||||
let blocks = blocks.into();
|
let blocks = blocks.into();
|
||||||
let first = key[0];
|
let first = hash(&key[..self.block_size]);
|
||||||
|
|
||||||
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
||||||
let child_id = self.nodes.insert(child);
|
let child_id = self.nodes.insert(child);
|
||||||
@ -496,10 +522,10 @@ impl RadixTrie {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Add a node to the parent.
|
/// Add a node to the parent.
|
||||||
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
|
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
|
||||||
// Unwrap here, passing in an unknown id is a programming error.
|
// Unwrap here, passing in an unknown id is a programming error.
|
||||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||||
if parent.children.insert(first, child_id).is_none() {
|
if parent.children.insert(hash, child_id).is_none() {
|
||||||
// Only increase reference count if child does not replace another child.
|
// Only increase reference count if child does not replace another child.
|
||||||
self.incref(parent_id)
|
self.incref(parent_id)
|
||||||
.expect("Failed to increase parent refcount");
|
.expect("Failed to increase parent refcount");
|
||||||
@ -517,7 +543,9 @@ impl RadixTrie {
|
|||||||
);
|
);
|
||||||
let parent_id = node.parent.expect("Attempted to remove root node");
|
let parent_id = node.parent.expect("Attempted to remove root node");
|
||||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||||
parent.children.remove(&node.key[0]);
|
|
||||||
|
let node_key = hash(&node.key[..self.block_size]);
|
||||||
|
parent.children.remove(&node_key);
|
||||||
self.decref(parent_id)
|
self.decref(parent_id)
|
||||||
.expect("Failed to decrease parent refcount");
|
.expect("Failed to decrease parent refcount");
|
||||||
node
|
node
|
||||||
@ -571,7 +599,7 @@ impl RadixTrie {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct TrieNode {
|
struct TrieNode {
|
||||||
blocks: Vec<u32>,
|
blocks: Vec<u32>,
|
||||||
children: HashMap<u32, NodeId>,
|
children: HashMap<u64, NodeId>,
|
||||||
key: Vec<u32>,
|
key: Vec<u32>,
|
||||||
last_accessed: u64,
|
last_accessed: u64,
|
||||||
parent: Option<NodeId>,
|
parent: Option<NodeId>,
|
||||||
|
@ -16,7 +16,6 @@ path = "src/main.rs"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
average = "0.14"
|
average = "0.14"
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
crossterm = "0.27"
|
|
||||||
float-ord = "0.3.2"
|
float-ord = "0.3.2"
|
||||||
serde = {version = "1.0.188", features = ["derive"]}
|
serde = {version = "1.0.188", features = ["derive"]}
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
@ -25,7 +24,7 @@ text-generation-client = { path = "../backends/client" }
|
|||||||
thiserror = "1.0.48"
|
thiserror = "1.0.48"
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||||
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
|
ratatui = "0.28.1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
hf-hub = { workspace = true }
|
hf-hub = { workspace = true }
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha)
|
A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha)
|
||||||
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
|
and powered by [Ratatui](https://github.com/ratatui/ratatui).
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
|
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
|
||||||
use crate::generation::{Decode, Message, Prefill};
|
use crate::generation::{Decode, Message, Prefill};
|
||||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||||
use text_generation_client::ClientError;
|
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
|
||||||
use tokio::sync::mpsc;
|
use ratatui::style::{Color, Modifier, Style};
|
||||||
use tui::backend::Backend;
|
use ratatui::text::{Line, Span};
|
||||||
use tui::layout::{Alignment, Constraint, Direction, Layout};
|
use ratatui::widgets::{
|
||||||
use tui::style::{Color, Modifier, Style};
|
|
||||||
use tui::text::{Line, Span};
|
|
||||||
use tui::widgets::{
|
|
||||||
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
|
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
|
||||||
};
|
};
|
||||||
use tui::{symbols, Frame};
|
use ratatui::{symbols, Frame};
|
||||||
|
use text_generation_client::ClientError;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
/// TUI powered App
|
/// TUI powered App
|
||||||
pub(crate) struct App {
|
pub(crate) struct App {
|
||||||
@ -153,7 +152,7 @@ impl App {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Render frame
|
/// Render frame
|
||||||
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
|
pub fn render(&mut self, f: &mut Frame) {
|
||||||
let batch_progress =
|
let batch_progress =
|
||||||
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
|
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
|
||||||
let run_progress =
|
let run_progress =
|
||||||
@ -172,7 +171,7 @@ impl App {
|
|||||||
]
|
]
|
||||||
.as_ref(),
|
.as_ref(),
|
||||||
)
|
)
|
||||||
.split(f.size());
|
.split(f.area());
|
||||||
|
|
||||||
// Top row horizontal layout
|
// Top row horizontal layout
|
||||||
let top = Layout::default()
|
let top = Layout::default()
|
||||||
@ -239,7 +238,7 @@ impl App {
|
|||||||
f.render_widget(helper, row5[0]);
|
f.render_widget(helper, row5[0]);
|
||||||
|
|
||||||
// Batch tabs
|
// Batch tabs
|
||||||
let titles = self
|
let titles: Vec<Line> = self
|
||||||
.data
|
.data
|
||||||
.batch_size
|
.batch_size
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
|
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
|
||||||
use crossterm::event;
|
use ratatui::crossterm::event;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
|
||||||
|
@ -6,13 +6,13 @@ mod utils;
|
|||||||
|
|
||||||
use crate::app::App;
|
use crate::app::App;
|
||||||
use crate::event::Event;
|
use crate::event::Event;
|
||||||
use crossterm::ExecutableCommand;
|
use ratatui::backend::CrosstermBackend;
|
||||||
|
use ratatui::crossterm::ExecutableCommand;
|
||||||
|
use ratatui::Terminal;
|
||||||
use std::io;
|
use std::io;
|
||||||
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
use tui::backend::CrosstermBackend;
|
|
||||||
use tui::Terminal;
|
|
||||||
|
|
||||||
/// Run benchmarking app
|
/// Run benchmarking app
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
@ -50,9 +50,9 @@ pub async fn run(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Initialize terminal properties
|
// Initialize terminal properties
|
||||||
crossterm::terminal::enable_raw_mode()?;
|
ratatui::crossterm::terminal::enable_raw_mode()?;
|
||||||
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;
|
||||||
io::stdout().execute(crossterm::cursor::Hide)?;
|
io::stdout().execute(ratatui::crossterm::cursor::Hide)?;
|
||||||
|
|
||||||
// Initialize terminal
|
// Initialize terminal
|
||||||
let mut terminal = {
|
let mut terminal = {
|
||||||
@ -128,9 +128,9 @@ pub async fn run(
|
|||||||
let _ = shutdown_guard_receiver.recv().await;
|
let _ = shutdown_guard_receiver.recv().await;
|
||||||
|
|
||||||
// Revert terminal to original view
|
// Revert terminal to original view
|
||||||
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;
|
||||||
crossterm::terminal::disable_raw_mode()?;
|
ratatui::crossterm::terminal::disable_raw_mode()?;
|
||||||
io::stdout().execute(crossterm::cursor::Show)?;
|
io::stdout().execute(ratatui::crossterm::cursor::Show)?;
|
||||||
|
|
||||||
let parameters_table = table::parameters_table(
|
let parameters_table = table::parameters_table(
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
|
@ -28,11 +28,17 @@ class ToolCall(BaseModel):
|
|||||||
function: dict
|
function: dict
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk(BaseModel):
|
||||||
|
type: str
|
||||||
|
text: Optional[str] = None
|
||||||
|
image_url: Any = None
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
# Role of the message sender
|
# Role of the message sender
|
||||||
role: str
|
role: str
|
||||||
# Content of the message
|
# Content of the message
|
||||||
content: Optional[str] = None
|
content: Optional[Union[str, List[Chunk]]] = None
|
||||||
# Optional name of the message sender
|
# Optional name of the message sender
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
# Tool calls associated with the chat completion
|
# Tool calls associated with the chat completion
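The `content` field of `Message` now accepts either a plain string or a list of `Chunk` objects. A minimal sketch of building a multimodal message with these updated types (assuming they are exposed as `text_generation.types.Message` and `Chunk`, and that the image URL is a placeholder):

```python
from text_generation.types import Chunk, Message

# Plain text: `content` can stay a simple string
system = Message(role="system", content="You are a helpful assistant.")

# Multimodal: `content` becomes a list of typed chunks
user = Message(
    role="user",
    content=[
        Chunk(type="text", text="What is in this image?"),
        Chunk(type="image_url", image_url={"url": "https://example.com/cat.png"}),
    ],
)
```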
|
||||||
@ -168,7 +174,7 @@ class ChatCompletionComplete(BaseModel):
|
|||||||
# Log probabilities for the chat completion
|
# Log probabilities for the chat completion
|
||||||
logprobs: Optional[Any]
|
logprobs: Optional[Any]
|
||||||
# Reason for completion
|
# Reason for completion
|
||||||
finish_reason: str
|
finish_reason: Optional[str]
|
||||||
# Usage details of the chat completion
|
# Usage details of the chat completion
|
||||||
usage: Optional[Any] = None
|
usage: Optional[Any] = None
|
||||||
|
|
||||||
@ -191,6 +197,7 @@ class ChatCompletionChunk(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
system_fingerprint: str
|
system_fingerprint: str
|
||||||
choices: List[Choice]
|
choices: List[Choice]
|
||||||
|
usage: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
class Parameters(BaseModel):
|
class Parameters(BaseModel):
|
||||||
|
@ -10,7 +10,7 @@
|
|||||||
"name": "Apache 2.0",
|
"name": "Apache 2.0",
|
||||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||||
},
|
},
|
||||||
"version": "2.2.1-dev0"
|
"version": "2.3.2-dev0"
|
||||||
},
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
"/": {
|
"/": {
|
||||||
@ -742,6 +742,14 @@
|
|||||||
},
|
},
|
||||||
"system_fingerprint": {
|
"system_fingerprint": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
|
},
|
||||||
|
"usage": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/Usage"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -937,6 +945,14 @@
|
|||||||
"stream": {
|
"stream": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
|
"stream_options": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/StreamOptions"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
"temperature": {
|
"temperature": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"format": "float",
|
"format": "float",
|
||||||
@ -1912,6 +1928,19 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"StreamOptions": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"include_usage"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"include_usage": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.",
|
||||||
|
"example": "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"StreamResponse": {
|
"StreamResponse": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
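Taken together, the new `stream_options` request field and the nullable `usage` field on `ChatCompletionChunk` let a client ask for token counts at the end of a stream. A sketch of how this could be exercised against the OpenAI-compatible endpoint, assuming a TGI server listening on `127.0.0.1:3000`:

```python
import json
import requests

response = requests.post(
    "http://127.0.0.1:3000/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": True,
        # Ask for a usage summary on the stream
        "stream_options": {"include_usage": True},
    },
    stream=True,
)

for line in response.iter_lines():
    if not line or not line.startswith(b"data:"):
        continue
    payload = line[len(b"data:"):].strip()
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    # `usage` is null on intermediate chunks and populated on the final one
    if chunk.get("usage"):
        print("prompt tokens:", chunk["usage"]["prompt_tokens"])
        print("completion tokens:", chunk["usage"]["completion_tokens"])
```

Intermediate chunks carry a null `usage`, matching the schema description above.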
@ -10,7 +10,7 @@ This diagram shows well there are these separate components:
|
|||||||
|
|
||||||
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
|
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
|
||||||
- **The model server**, responsible for receiving the gRPC requests and processing the inference on the model. If the model is sharded across multiple accelerators (e.g. multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
|
- **The model server**, responsible for receiving the gRPC requests and processing the inference on the model. If the model is sharded across multiple accelerators (e.g. multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
|
||||||
- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
|
- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
|
||||||
|
|
||||||
The router and the model server can be two different machines; they do not need to be deployed together.
|
The router and the model server can be two different machines; they do not need to be deployed together.
|
||||||
|
|
||||||
|
@ -36,7 +36,13 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
|||||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||||
```
|
```
|
||||||
|
|
||||||
additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
|
To specify model revision, use `adapter_id@revision`, as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
|
||||||
|
```
|
||||||
|
|
||||||
|
To use a locally stored LoRA adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
|
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
|
||||||
@ -72,6 +78,22 @@ curl 127.0.0.1:3000/generate \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you are using a LoRA adapter stored locally that was set in the following manner: `LORA_ADAPTERS=myadapter=/some/path/to/adapter`, here is an example request (a Python equivalent follows the curl example):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl 127.0.0.1:3000/generate \
|
||||||
|
-X POST \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"inputs": "Hello who are you?",
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": 40,
|
||||||
|
"adapter_id": "myadapter"
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
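The same request can be issued from Python; a small sketch using `requests` (assuming the server above is reachable on `127.0.0.1:3000` and was started with `LORA_ADAPTERS=myadapter=/some/path/to/adapter`):

```python
import requests

response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "Hello who are you?",
        "parameters": {"max_new_tokens": 40, "adapter_id": "myadapter"},
    },
)
print(response.json()["generated_text"])
```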
|
||||||
|
|
||||||
|
|
||||||
> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.
|
> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.
|
||||||
|
|
||||||
An updated tutorial with detailed examples will be published soon. Stay tuned!
|
An updated tutorial with detailed examples will be published soon. Stay tuned!
|
||||||
|
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
|||||||
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||||
--device=/dev/kfd --device=/dev/dri --group-add video \
|
--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||||
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0-rocm \
|
ghcr.io/huggingface/text-generation-inference:2.3.1-rocm \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
|
|||||||
|
|
||||||
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
|
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
|
||||||
|
|
||||||
|
## Custom PagedAttention
|
||||||
|
|
||||||
|
For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
|
||||||
|
|
||||||
|
The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
|
||||||
|
|
||||||
## Unsupported features
|
## Unsupported features
|
||||||
|
|
||||||
The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:
|
The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:
|
||||||
|
@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
|||||||
docker run --rm --privileged --cap-add=sys_nice \
|
docker run --rm --privileged --cap-add=sys_nice \
|
||||||
--device=/dev/dri \
|
--device=/dev/dri \
|
||||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
|
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-xpu \
|
||||||
--model-id $model --cuda-graphs 0
|
--model-id $model --cuda-graphs 0
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
|||||||
docker run --rm --privileged --cap-add=sys_nice \
|
docker run --rm --privileged --cap-add=sys_nice \
|
||||||
--device=/dev/dri \
|
--device=/dev/dri \
|
||||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
|
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-cpu \
|
||||||
--model-id $model --cuda-graphs 0
|
--model-id $model --cuda-graphs 0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
|||||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0 \
|
ghcr.io/huggingface/text-generation-inference:2.3.1 \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -11,10 +11,19 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
|||||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0 \
|
ghcr.io/huggingface/text-generation-inference:2.3.1 \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
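
Once the container is running, you can smoke-test it from Python. The sketch below assumes the `-p 8080:80` mapping from the command above and posts to the OpenAI-compatible `/v1/chat/completions` route; the placeholder `"tgi"` model name is an assumption for illustration.

```python
# Quick smoke test against the container started above.
# Assumes the `-p 8080:80` port mapping; adjust host/port as needed.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",  # placeholder name; TGI serves the model it was launched with
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "max_tokens": 50,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```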
|
||||||
|
|
||||||
|
<Tip>

If you want to serve gated or private models, which provide
controlled access to sensitive or proprietary content, refer to
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
for detailed instructions.

</Tip>

### Supported hardware
|
### Supported hardware
|
||||||
|
|
||||||
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), and [Using TGI with Inferentia](./installation_inferentia) guides, depending on which hardware you would like to deploy TGI on.
|
||||||
|
@ -55,7 +55,9 @@ Options:
|
|||||||
## QUANTIZE
|
## QUANTIZE
|
||||||
```shell
|
```shell
|
||||||
--quantize <QUANTIZE>
|
--quantize <QUANTIZE>
|
||||||
Whether you want the model to be quantized
|
Quantization method to use for the model. It is not necessary to specify this option for pre-quantized models, since the quantization method is read from the model configuration.
|
||||||
|
|
||||||
|
Marlin kernels will be used automatically for GPTQ/AWQ models.
|
||||||
|
|
||||||
[env: QUANTIZE=]
|
[env: QUANTIZE=]
|
||||||
|
|
||||||
@ -87,6 +89,15 @@ Options:
|
|||||||
[env: DTYPE=]
|
[env: DTYPE=]
|
||||||
[possible values: float16, bfloat16]
|
[possible values: float16, bfloat16]
|
||||||
|
|
||||||
|
```
|
||||||
|
## KV_CACHE_DTYPE
|
||||||
|
```shell
|
||||||
|
--kv-cache-dtype <KV_CACHE_DTYPE>
|
||||||
|
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
|
||||||
|
|
||||||
|
[env: KV_CACHE_DTYPE=]
|
||||||
|
[possible values: fp8_e5m2]
|
||||||
|
|
||||||
```
|
```
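
These options are ordinary CLI flags to `text-generation-launcher`, so a wrapper can assemble them programmatically, much like the integration-test fixture later in this diff does. A minimal sketch (the model id and flag values are placeholders):

```python
# Hedged sketch: building a text-generation-launcher command line in Python,
# mirroring how the integration-test fixture in this diff appends flags.
model_id = "teknium/OpenHermes-2.5-Mistral-7B"  # placeholder
quantize = None              # e.g. "awq" or "gptq"; omit for pre-quantized models
kv_cache_dtype = "fp8_e5m2"  # currently the only documented value, CUDA only

args = ["text-generation-launcher", "--model-id", model_id]
if quantize is not None:
    args += ["--quantize", quantize]
if kv_cache_dtype is not None:
    args += ["--kv-cache-dtype", kv_cache_dtype]

print(" ".join(args))
```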
|
||||||
## TRUST_REMOTE_CODE
|
## TRUST_REMOTE_CODE
|
||||||
```shell
|
```shell
|
||||||
|
@ -20,6 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware
|
|||||||
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
|
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
|
||||||
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
|
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
|
||||||
- [Phi](https://huggingface.co/microsoft/phi-1_5)
|
- [Phi](https://huggingface.co/microsoft/phi-1_5)
|
||||||
|
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
|
||||||
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
|
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
|
||||||
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
|
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
|
||||||
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
|
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
|
||||||
@ -34,6 +35,7 @@ Text Generation Inference enables serving optimized models on specific hardware
|
|||||||
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
|
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
|
||||||
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
|
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
|
||||||
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
|
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
|
||||||
|
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
|
||||||
|
|
||||||
|
|
||||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyway to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||||
|
38
flake.lock
38
flake.lock
@ -479,11 +479,11 @@
|
|||||||
"systems": "systems_6"
|
"systems": "systems_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1710146030,
|
"lastModified": 1726560853,
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -497,11 +497,11 @@
|
|||||||
"systems": "systems_7"
|
"systems": "systems_7"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1710146030,
|
"lastModified": 1726560853,
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -718,11 +718,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs_6": {
|
"nixpkgs_6": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1724915739,
|
"lastModified": 1727675176,
|
||||||
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
|
"narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
|
"rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -853,11 +853,11 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1726021481,
|
"lastModified": 1727836133,
|
||||||
"narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=",
|
"narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70",
|
"rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -978,16 +978,16 @@
|
|||||||
"nixpkgs": "nixpkgs_6"
|
"nixpkgs": "nixpkgs_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1725950569,
|
"lastModified": 1728029332,
|
||||||
"narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=",
|
"narHash": "sha256-j0RX3a67lvi2PC5w6J5DHTxM+l96J/OV5sAf34IUfUo=",
|
||||||
"owner": "danieldk",
|
"owner": "huggingface",
|
||||||
"repo": "tgi-nix",
|
"repo": "text-generation-inference-nix",
|
||||||
"rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46",
|
"rev": "98049f853346ca780b81fee730715c90d33ac2b4",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "danieldk",
|
"owner": "huggingface",
|
||||||
"repo": "tgi-nix",
|
"repo": "text-generation-inference-nix",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
94
flake.nix
94
flake.nix
@ -5,7 +5,7 @@
|
|||||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
};
|
};
|
||||||
nix-filter.url = "github:numtide/nix-filter";
|
nix-filter.url = "github:numtide/nix-filter";
|
||||||
tgi-nix.url = "github:danieldk/tgi-nix";
|
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
|
||||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
rust-overlay = {
|
rust-overlay = {
|
||||||
@ -37,6 +37,7 @@
|
|||||||
overlays = [
|
overlays = [
|
||||||
rust-overlay.overlays.default
|
rust-overlay.overlays.default
|
||||||
tgi-nix.overlays.default
|
tgi-nix.overlays.default
|
||||||
|
(import nix/overlay.nix)
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
|
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
|
||||||
@ -67,8 +68,37 @@
|
|||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
||||||
|
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
|
checks = {
|
||||||
|
rust =
|
||||||
|
with pkgs;
|
||||||
|
rustPlatform.buildRustPackage {
|
||||||
|
name = "rust-checks";
|
||||||
|
src = ./.;
|
||||||
|
cargoLock = {
|
||||||
|
lockFile = ./Cargo.lock;
|
||||||
|
};
|
||||||
|
buildInputs = [ openssl.dev ];
|
||||||
|
nativeBuildInputs = [
|
||||||
|
clippy
|
||||||
|
pkg-config
|
||||||
|
protobuf
|
||||||
|
python3
|
||||||
|
rustfmt
|
||||||
|
];
|
||||||
|
buildPhase = ''
|
||||||
|
cargo check
|
||||||
|
'';
|
||||||
|
checkPhase = ''
|
||||||
|
cargo fmt -- --check
|
||||||
|
cargo test -j $NIX_BUILD_CORES
|
||||||
|
cargo clippy
|
||||||
|
'';
|
||||||
|
installPhase = "touch $out";
|
||||||
|
};
|
||||||
|
};
|
||||||
formatter = pkgs.nixfmt-rfc-style;
|
formatter = pkgs.nixfmt-rfc-style;
|
||||||
devShells = with pkgs; rec {
|
devShells = with pkgs; rec {
|
||||||
default = pure;
|
default = pure;
|
||||||
@ -84,10 +114,11 @@
|
|||||||
test = mkShell {
|
test = mkShell {
|
||||||
buildInputs =
|
buildInputs =
|
||||||
[
|
[
|
||||||
# benchmark
|
benchmark
|
||||||
# launcher
|
launcher
|
||||||
# router
|
router
|
||||||
server
|
server
|
||||||
|
client
|
||||||
openssl.dev
|
openssl.dev
|
||||||
pkg-config
|
pkg-config
|
||||||
cargo
|
cargo
|
||||||
@ -102,52 +133,17 @@
|
|||||||
pre-commit
|
pre-commit
|
||||||
ruff
|
ruff
|
||||||
]);
|
]);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
impure = mkShell {
|
impure = callPackage ./nix/impure-shell.nix { inherit server; };
|
||||||
buildInputs =
|
|
||||||
[
|
|
||||||
openssl.dev
|
|
||||||
pkg-config
|
|
||||||
(rust-bin.stable.latest.default.override {
|
|
||||||
extensions = [
|
|
||||||
"rust-analyzer"
|
|
||||||
"rust-src"
|
|
||||||
];
|
|
||||||
})
|
|
||||||
protobuf
|
|
||||||
]
|
|
||||||
++ (with python3.pkgs; [
|
|
||||||
venvShellHook
|
|
||||||
docker
|
|
||||||
pip
|
|
||||||
ipdb
|
|
||||||
click
|
|
||||||
pyright
|
|
||||||
pytest
|
|
||||||
pytest-asyncio
|
|
||||||
ruff
|
|
||||||
syrupy
|
|
||||||
]);
|
|
||||||
|
|
||||||
inputsFrom = [ server ];
|
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
|
||||||
|
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
|
||||||
venvDir = "./.venv";
|
|
||||||
|
|
||||||
postVenvCreation = ''
|
|
||||||
unset SOURCE_DATE_EPOCH
|
|
||||||
( cd server ; python -m pip install --no-dependencies -e . )
|
|
||||||
( cd clients/python ; python -m pip install --no-dependencies -e . )
|
|
||||||
'';
|
|
||||||
postShellHook = ''
|
|
||||||
unset SOURCE_DATE_EPOCH
|
|
||||||
export PATH=$PATH:~/.cargo/bin
|
|
||||||
'';
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
packages.default = pkgs.writeShellApplication {
|
packages = rec {
|
||||||
|
default = pkgs.writeShellApplication {
|
||||||
name = "text-generation-inference";
|
name = "text-generation-inference";
|
||||||
runtimeInputs = [
|
runtimeInputs = [
|
||||||
server
|
server
|
||||||
@ -157,6 +153,16 @@
|
|||||||
${launcher}/bin/text-generation-launcher "$@"
|
${launcher}/bin/text-generation-launcher "$@"
|
||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
dockerImage = pkgs.callPackage nix/docker.nix {
|
||||||
|
text-generation-inference = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
dockerImageStreamed = pkgs.callPackage nix/docker.nix {
|
||||||
|
text-generation-inference = default;
|
||||||
|
stream = true;
|
||||||
|
};
|
||||||
|
};
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -336,12 +336,14 @@ def launcher(event_loop):
|
|||||||
use_flash_attention: bool = True,
|
use_flash_attention: bool = True,
|
||||||
disable_grammar_support: bool = False,
|
disable_grammar_support: bool = False,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
|
kv_cache_dtype: Optional[str] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
max_input_length: Optional[int] = None,
|
max_input_length: Optional[int] = None,
|
||||||
max_batch_prefill_tokens: Optional[int] = None,
|
max_batch_prefill_tokens: Optional[int] = None,
|
||||||
max_total_tokens: Optional[int] = None,
|
max_total_tokens: Optional[int] = None,
|
||||||
lora_adapters: Optional[List[str]] = None,
|
lora_adapters: Optional[List[str]] = None,
|
||||||
cuda_graphs: Optional[List[int]] = None,
|
cuda_graphs: Optional[List[int]] = None,
|
||||||
|
attention: Optional[str] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
master_port = random.randint(10_000, 20_000)
|
master_port = random.randint(10_000, 20_000)
|
||||||
@ -374,6 +376,9 @@ def launcher(event_loop):
|
|||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
args.append("--dtype")
|
args.append("--dtype")
|
||||||
args.append(dtype)
|
args.append(dtype)
|
||||||
|
if kv_cache_dtype is not None:
|
||||||
|
args.append("--kv-cache-dtype")
|
||||||
|
args.append(kv_cache_dtype)
|
||||||
if revision is not None:
|
if revision is not None:
|
||||||
args.append("--revision")
|
args.append("--revision")
|
||||||
args.append(revision)
|
args.append(revision)
|
||||||
@ -401,6 +406,8 @@ def launcher(event_loop):
|
|||||||
|
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
env["USE_FLASH_ATTENTION"] = "false"
|
env["USE_FLASH_ATTENTION"] = "false"
|
||||||
|
if attention is not None:
|
||||||
|
env["ATTENTION"] = attention
|
||||||
|
|
||||||
with tempfile.TemporaryFile("w+") as tmp:
|
with tempfile.TemporaryFile("w+") as tmp:
|
||||||
# We'll output stdout/stderr to a temporary file. Using a pipe
|
# We'll output stdout/stderr to a temporary file. Using a pipe
|
||||||
@ -431,12 +438,14 @@ def launcher(event_loop):
|
|||||||
use_flash_attention: bool = True,
|
use_flash_attention: bool = True,
|
||||||
disable_grammar_support: bool = False,
|
disable_grammar_support: bool = False,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
|
kv_cache_dtype: Optional[str] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
max_input_length: Optional[int] = None,
|
max_input_length: Optional[int] = None,
|
||||||
max_batch_prefill_tokens: Optional[int] = None,
|
max_batch_prefill_tokens: Optional[int] = None,
|
||||||
max_total_tokens: Optional[int] = None,
|
max_total_tokens: Optional[int] = None,
|
||||||
lora_adapters: Optional[List[str]] = None,
|
lora_adapters: Optional[List[str]] = None,
|
||||||
cuda_graphs: Optional[List[int]] = None,
|
cuda_graphs: Optional[List[int]] = None,
|
||||||
|
attention: Optional[str] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
|
|
||||||
@ -452,6 +461,9 @@ def launcher(event_loop):
|
|||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
args.append("--dtype")
|
args.append("--dtype")
|
||||||
args.append(dtype)
|
args.append(dtype)
|
||||||
|
if kv_cache_dtype is not None:
|
||||||
|
args.append("--kv-cache-dtype")
|
||||||
|
args.append(kv_cache_dtype)
|
||||||
if revision is not None:
|
if revision is not None:
|
||||||
args.append("--revision")
|
args.append("--revision")
|
||||||
args.append(revision)
|
args.append(revision)
|
||||||
@ -491,6 +503,8 @@ def launcher(event_loop):
|
|||||||
}
|
}
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
env["USE_FLASH_ATTENTION"] = "false"
|
env["USE_FLASH_ATTENTION"] = "false"
|
||||||
|
if attention is not None:
|
||||||
|
env["ATTENTION"] = attention
|
||||||
|
|
||||||
if HF_TOKEN is not None:
|
if HF_TOKEN is not None:
|
||||||
env["HF_TOKEN"] = HF_TOKEN
|
env["HF_TOKEN"] = HF_TOKEN
|
||||||
@ -522,6 +536,7 @@ def launcher(event_loop):
|
|||||||
devices=devices,
|
devices=devices,
|
||||||
volumes=volumes,
|
volumes=volumes,
|
||||||
ports={"80/tcp": port},
|
ports={"80/tcp": port},
|
||||||
|
healthcheck={"timeout": int(10 * 1e9)},
|
||||||
shm_size="1G",
|
shm_size="1G",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -582,7 +597,6 @@ def generate_multi():
|
|||||||
max_new_tokens: int,
|
max_new_tokens: int,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
) -> List[Response]:
|
) -> List[Response]:
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
arange = np.arange(len(prompts))
|
arange = np.arange(len(prompts))
|
||||||
|
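
With the new `kv_cache_dtype` and `attention` parameters plumbed through the fixture, a test can request them roughly as sketched below. This is a hedged illustration: the fixture/handle API follows the conftest.py shown above, and the model id and attention value are placeholders.

```python
# Hedged sketch of using the extended launcher fixture; the handle API
# (health, client) follows the conftest.py above, values are placeholders.
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_kv_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        kv_cache_dtype="fp8_e5m2",
        attention="flashinfer",
    ) as handle:
        yield handle


@pytest.mark.asyncio
async def test_fp8_kv_cache(flash_llama_fp8_kv_handle):
    await flash_llama_fp8_kv_handle.health(300)
    client = flash_llama_fp8_kv_handle.client
    response = await client.generate("What is deep learning?", max_new_tokens=10)
    assert response.details.generated_tokens == 10
```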
@ -0,0 +1,206 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "**",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "Deep",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " Learning",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": ":",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " An",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " Overview",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "**\n",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "================================",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "=====",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "\n\n",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 40,
|
||||||
|
"total_tokens": 50
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
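
The snapshot above captures a streamed chat completion in which every intermediate chunk reports `"usage": null` and only the final chunk carries token counts. A minimal consumer, sketched here with `requests` (host, port, and the `stream_options` flag are assumptions for illustration), accumulates `delta.content` and reads `usage` from the last chunk:

```python
# Minimal sketch of consuming streamed chunks like the ones captured above.
# Host/port and stream_options are assumptions for illustration.
import json
import requests

with requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "max_tokens": 10,
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
    timeout=60,
) as resp:
    text, usage = "", None
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        payload = line[len(b"data:"):].strip()
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if chunk["choices"]:
            text += chunk["choices"][0]["delta"].get("content") or ""
        usage = chunk.get("usage") or usage

print(text)
print(usage)  # e.g. {'completion_tokens': 10, 'prompt_tokens': 40, 'total_tokens': 50}
```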
@ -24,13 +24,13 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 1736,
|
"id": 1736,
|
||||||
"logprob": -2.03125,
|
"logprob": -2.109375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " form"
|
"text": " form"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 109,
|
"id": 109,
|
||||||
"logprob": -1.8671875,
|
"logprob": -1.90625,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n\n"
|
"text": "\n\n"
|
||||||
},
|
},
|
||||||
@ -42,48 +42,48 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2121,
|
"id": 2121,
|
||||||
"logprob": -1.8125,
|
"logprob": -1.796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " test"
|
"text": " test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3853,
|
"id": 3853,
|
||||||
"logprob": -0.24121094,
|
"logprob": -0.24511719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1736,
|
"id": 1736,
|
||||||
"logprob": -0.100097656,
|
"logprob": -0.09326172,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " form"
|
"text": " form"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 603,
|
"id": 603,
|
||||||
"logprob": -0.9453125,
|
"logprob": -0.95703125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 476,
|
"id": 1671,
|
||||||
"logprob": -1.703125,
|
"logprob": -1.5859375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " used"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4551,
|
"id": 577,
|
||||||
"logprob": -2.453125,
|
"logprob": -0.39257812,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " document"
|
"text": " to"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 674,
|
"id": 3853,
|
||||||
"logprob": -0.796875,
|
"logprob": -1.25,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " form\n\nThe test request form is a document that"
|
"generated_text": " form\n\nThe test request form is used to request"
|
||||||
}
|
}
|
||||||
|
@ -11,12 +11,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2015,
|
"id": 2015,
|
||||||
"logprob": -9.640625,
|
"logprob": -9.6484375,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3853,
|
"id": 3853,
|
||||||
"logprob": -10.375,
|
"logprob": -10.3671875,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -24,19 +24,19 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 604,
|
"id": 604,
|
||||||
"logprob": -0.2824707,
|
"logprob": -0.28271484,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " for"
|
"text": " for"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 573,
|
"id": 573,
|
||||||
"logprob": -0.19030762,
|
"logprob": -0.18493652,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " the"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16819,
|
"id": 16819,
|
||||||
"logprob": -1.4892578,
|
"logprob": -1.4804688,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " detection"
|
"text": " detection"
|
||||||
},
|
},
|
||||||
@ -46,44 +46,44 @@
|
|||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 573,
|
|
||||||
"logprob": -2.0195312,
|
|
||||||
"special": false,
|
|
||||||
"text": " the"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 8566,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " presence"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 689,
|
|
||||||
"logprob": -0.16491699,
|
|
||||||
"special": false,
|
|
||||||
"text": " or"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 14862,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " absence"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 576,
|
|
||||||
"logprob": -0.9946289,
|
|
||||||
"special": false,
|
|
||||||
"text": " of"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 671,
|
"id": 671,
|
||||||
"logprob": -0.5263672,
|
"logprob": -2.1738281,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " an"
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24646,
|
||||||
|
"logprob": -3.0449219,
|
||||||
|
"special": false,
|
||||||
|
"text": " RNA"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12369,
|
||||||
|
"logprob": -0.19299316,
|
||||||
|
"special": false,
|
||||||
|
"text": " virus"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 575,
|
||||||
|
"logprob": -0.10632324,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6022,
|
||||||
|
"logprob": -0.98095703,
|
||||||
|
"special": false,
|
||||||
|
"text": " patients"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -1.3095703,
|
||||||
|
"special": false,
|
||||||
|
"text": " who"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Test request for the detection of the presence or absence of an"
|
"generated_text": "Test request for the detection of an RNA virus in patients who"
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,104 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -5.6328125,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.2265625,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -9.1015625,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.8085938,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -1.0439453,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 18682,
|
||||||
|
"logprob": -2.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.079956055,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.2763672,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.37548828,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -1.4628906,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.02885437,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.2565918,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0063438416,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -1.3056641,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.6035156,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 3,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -22.96875,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -10.71875,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -2.6992188,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -4.8398438,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 720,
|
||||||
|
"logprob": -0.4411621,
|
||||||
|
"special": false,
|
||||||
|
"text": " \n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 220,
|
||||||
|
"logprob": -0.35864258,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 128001,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": true,
|
||||||
|
"text": "<|end_of_text|>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "What is deep learning? \n "
|
||||||
|
}
|
@ -0,0 +1,418 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -5.6328125,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.2265625,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -9.1015625,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.8085938,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -1.0439453,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 18682,
|
||||||
|
"logprob": -2.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.07897949,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.27734375,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.37402344,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -1.4511719,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.02909851,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.25854492,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0061798096,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -1.3046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.5537109,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -5.6328125,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.2265625,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -9.1015625,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.8085938,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -1.0439453,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 18682,
|
||||||
|
"logprob": -2.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.07897949,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.27734375,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.37402344,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -1.4511719,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.02909851,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.25854492,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0061798096,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -1.3046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.5537109,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -5.6328125,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.2265625,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -9.1015625,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.8085938,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -1.0439453,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 18682,
|
||||||
|
"logprob": -2.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.07897949,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.27734375,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.37402344,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -1.4511719,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.02909851,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.25854492,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0061798096,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -1.3046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.5537109,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -5.6328125,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.2265625,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -9.1015625,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.8085938,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -1.0439453,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 18682,
|
||||||
|
"logprob": -2.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.07897949,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.27734375,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.37402344,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -1.4511719,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.02909851,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.25854492,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0061798096,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -1.3046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.5537109,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||||
|
}
|
||||||
|
]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,114 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1824,
|
||||||
|
"logprob": -6.1445312,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -1.4648438,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21135,
|
||||||
|
"logprob": -13.6875,
|
||||||
|
"text": "gradient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -1.6005859,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.39526367,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.640625,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.18774414,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 20910,
|
||||||
|
"logprob": -0.96484375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Grad"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 722,
|
||||||
|
"logprob": -0.003168106,
|
||||||
|
"special": false,
|
||||||
|
"text": "ient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -0.16540527,
|
||||||
|
"special": false,
|
||||||
|
"text": " descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.08886719,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 396,
|
||||||
|
"logprob": -0.75878906,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18586,
|
||||||
|
"logprob": -0.5703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " optimization"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9464,
|
||||||
|
"logprob": -0.11242676,
|
||||||
|
"special": false,
|
||||||
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1307,
|
||||||
|
"logprob": -0.7939453,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -0.17102051,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26518,
|
||||||
|
"logprob": -0.34326172,
|
||||||
|
"special": false,
|
||||||
|
"text": " minimize"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
|
||||||
|
}
|
@ -0,0 +1,99 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -17.234375,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -7.4375,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.8046875,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.33032227,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 1313,
|
||||||
|
"logprob": -2.3613281,
|
||||||
|
"special": false,
|
||||||
|
"text": "It"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3969,
|
||||||
|
"logprob": -0.7285156,
|
||||||
|
"special": false,
|
||||||
|
"text": " seems"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -1.3466797,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 528,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " me"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28725,
|
||||||
|
"logprob": -1.6757812,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 369,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 513,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": " if"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 368,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " you"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28742,
|
||||||
|
"logprob": -2.4921875,
|
||||||
|
"special": false,
|
||||||
|
"text": "'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 267,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "re"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "What is gradient descent?\n\nIt seems to me, that if you're"
|
||||||
|
}
|
@ -0,0 +1,458 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1824,
|
||||||
|
"logprob": -6.1445312,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -1.4648438,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21135,
|
||||||
|
"logprob": -13.6875,
|
||||||
|
"text": "gradient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -1.6005859,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.39526367,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.640625,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.18774414,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 20910,
|
||||||
|
"logprob": -0.96484375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Grad"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 722,
|
||||||
|
"logprob": -0.003168106,
|
||||||
|
"special": false,
|
||||||
|
"text": "ient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -0.16369629,
|
||||||
|
"special": false,
|
||||||
|
"text": " descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.0881958,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 396,
|
||||||
|
"logprob": -0.76708984,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18586,
|
||||||
|
"logprob": -0.57373047,
|
||||||
|
"special": false,
|
||||||
|
"text": " optimization"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9464,
|
||||||
|
"logprob": -0.11291504,
|
||||||
|
"special": false,
|
||||||
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1307,
|
||||||
|
"logprob": -0.79589844,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -0.1694336,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26518,
|
||||||
|
"logprob": -0.34350586,
|
||||||
|
"special": false,
|
||||||
|
"text": " minimize"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
        "logprob": null, "text": "<s>" },
      { "id": 1824, "logprob": -6.1445312, "text": "What" },
      { "id": 349, "logprob": -1.4677734, "text": "is" },
      { "id": 21135, "logprob": -13.6875, "text": "gradient" },
      { "id": 24871, "logprob": -1.6015625, "text": "descent" },
      { "id": 28804, "logprob": -0.39453125, "text": "?" },
      { "id": 13, "logprob": -0.6435547, "text": "\n" },
      { "id": 13, "logprob": -0.18713379, "text": "\n" }
    ],
    "seed": null,
    "tokens": [
      { "id": 20910, "logprob": -0.9628906, "special": false, "text": "Grad" },
      { "id": 722, "logprob": -0.0032176971, "special": false, "text": "ient" },
      { "id": 24871, "logprob": -0.16540527, "special": false, "text": " descent" },
      { "id": 349, "logprob": -0.08898926, "special": false, "text": " is" },
      { "id": 396, "logprob": -0.765625, "special": false, "text": " an" },
      { "id": 18586, "logprob": -0.5708008, "special": false, "text": " optimization" },
      { "id": 9464, "logprob": -0.11401367, "special": false, "text": " algorithm" },
      { "id": 1307, "logprob": -0.7963867, "special": false, "text": " used" },
      { "id": 298, "logprob": -0.17028809, "special": false, "text": " to" },
      { "id": 26518, "logprob": -0.34326172, "special": false, "text": " minimize" }
    ],
    "top_tokens": null
  },
  "generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 1824, "logprob": -6.140625, "text": "What" },
      { "id": 349, "logprob": -1.4658203, "text": "is" },
      { "id": 21135, "logprob": -13.6796875, "text": "gradient" },
      { "id": 24871, "logprob": -1.5898438, "text": "descent" },
      { "id": 28804, "logprob": -0.3955078, "text": "?" },
      { "id": 13, "logprob": -0.64501953, "text": "\n" },
      { "id": 13, "logprob": -0.18493652, "text": "\n" }
    ],
    "seed": null,
    "tokens": [
      { "id": 20910, "logprob": -0.9580078, "special": false, "text": "Grad" },
      { "id": 722, "logprob": -0.0032176971, "special": false, "text": "ient" },
      { "id": 24871, "logprob": -0.16552734, "special": false, "text": " descent" },
      { "id": 349, "logprob": -0.08874512, "special": false, "text": " is" },
      { "id": 396, "logprob": -0.75878906, "special": false, "text": " an" },
      { "id": 18586, "logprob": -0.5703125, "special": false, "text": " optimization" },
      { "id": 9464, "logprob": -0.11236572, "special": false, "text": " algorithm" },
      { "id": 1307, "logprob": -0.79541016, "special": false, "text": " used" },
      { "id": 298, "logprob": -0.17102051, "special": false, "text": " to" },
      { "id": 26518, "logprob": -0.34326172, "special": false, "text": " minimize" }
    ],
    "top_tokens": null
  },
  "generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 1824, "logprob": -6.1328125, "text": "What" },
      { "id": 349, "logprob": -1.4658203, "text": "is" },
      { "id": 21135, "logprob": -13.6796875, "text": "gradient" },
      { "id": 24871, "logprob": -1.5947266, "text": "descent" },
      { "id": 28804, "logprob": -0.39648438, "text": "?" },
      { "id": 13, "logprob": -0.6464844, "text": "\n" },
      { "id": 13, "logprob": -0.18688965, "text": "\n" }
    ],
    "seed": null,
    "tokens": [
      { "id": 20910, "logprob": -0.9609375, "special": false, "text": "Grad" },
      { "id": 722, "logprob": -0.003168106, "special": false, "text": "ient" },
      { "id": 24871, "logprob": -0.16601562, "special": false, "text": " descent" },
      { "id": 349, "logprob": -0.088134766, "special": false, "text": " is" },
      { "id": 396, "logprob": -0.7597656, "special": false, "text": " an" },
      { "id": 18586, "logprob": -0.5708008, "special": false, "text": " optimization" },
      { "id": 9464, "logprob": -0.11291504, "special": false, "text": " algorithm" },
      { "id": 1307, "logprob": -0.7944336, "special": false, "text": " used" },
      { "id": 298, "logprob": -0.17102051, "special": false, "text": " to" },
      { "id": 26518, "logprob": -0.34399414, "special": false, "text": " minimize" }
    ],
    "top_tokens": null
  },
  "generated_text": "Gradient descent is an optimization algorithm used to minimize"
}
]
@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 3735, "logprob": -11.0078125, "text": "Test" },
      { "id": 2159, "logprob": -13.59375, "text": "request" }
    ],
    "seed": null,
    "tokens": [
      { "id": 13, "logprob": -1.7089844, "special": false, "text": "\n" },
      { "id": 13, "logprob": -0.68847656, "special": false, "text": "\n" },
      { "id": 28771, "logprob": -1.9394531, "special": false, "text": "#" },
      { "id": 3735, "logprob": -2.8808594, "special": false, "text": " Test" },
      { "id": 2159, "logprob": -0.37280273, "special": false, "text": " request" },
      { "id": 13, "logprob": -0.26098633, "special": false, "text": "\n" },
      { "id": 13, "logprob": -0.0017137527, "special": false, "text": "\n" },
      { "id": 1064, "logprob": -2.2695312, "special": false, "text": "##" },
      { "id": 3735, "logprob": -1.9238281, "special": false, "text": " Test" },
      { "id": 2159, "logprob": -0.48828125, "special": false, "text": " request" }
    ],
    "top_tokens": null
  },
  "generated_text": "\n\n# Test request\n\n## Test request"
}
@ -0,0 +1,89 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 3735, "logprob": -11.0078125, "text": "Test" },
      { "id": 2159, "logprob": -13.59375, "text": "request" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 13, "logprob": -0.34838867, "special": false, "text": "\n" },
      { "id": 13940, "logprob": -0.38916016, "special": false, "text": "``" },
      { "id": 28832, "logprob": 0.0, "special": false, "text": "`" },
      { "id": 3371, "logprob": -1.2529297, "special": false, "text": "json" },
      { "id": 13, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 28751, "logprob": 0.0, "special": false, "text": "{" },
      { "id": 13, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 2287, "logprob": 0.0, "special": false, "text": " " },
      { "id": 345, "logprob": 0.0, "special": false, "text": " \"" },
      { "id": 3134, "logprob": -0.640625, "special": false, "text": "request" }
    ],
    "top_tokens": null
  },
  "generated_text": "Test request\n```json\n{\n \"request"
}
@ -0,0 +1,358 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 3735, "logprob": -11.0078125, "text": "Test" },
        { "id": 2159, "logprob": -13.59375, "text": "request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.7089844, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.68847656, "special": false, "text": "\n" },
        { "id": 28771, "logprob": -1.9394531, "special": false, "text": "#" },
        { "id": 3735, "logprob": -2.8828125, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.37329102, "special": false, "text": " request" },
        { "id": 13, "logprob": -0.2602539, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.0017185211, "special": false, "text": "\n" },
        { "id": 1064, "logprob": -2.2753906, "special": false, "text": "##" },
        { "id": 3735, "logprob": -1.9316406, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.48217773, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 3735, "logprob": -11.0078125, "text": "Test" },
        { "id": 2159, "logprob": -13.59375, "text": "request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.7089844, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.68847656, "special": false, "text": "\n" },
        { "id": 28771, "logprob": -1.9394531, "special": false, "text": "#" },
        { "id": 3735, "logprob": -2.8828125, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.37329102, "special": false, "text": " request" },
        { "id": 13, "logprob": -0.2602539, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.0017185211, "special": false, "text": "\n" },
        { "id": 1064, "logprob": -2.2753906, "special": false, "text": "##" },
        { "id": 3735, "logprob": -1.9316406, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.48217773, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 3735, "logprob": -11.0078125, "text": "Test" },
        { "id": 2159, "logprob": -13.59375, "text": "request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.7089844, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.68847656, "special": false, "text": "\n" },
        { "id": 28771, "logprob": -1.9394531, "special": false, "text": "#" },
        { "id": 3735, "logprob": -2.8828125, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.37329102, "special": false, "text": " request" },
        { "id": 13, "logprob": -0.2602539, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.0017185211, "special": false, "text": "\n" },
        { "id": 1064, "logprob": -2.2753906, "special": false, "text": "##" },
        { "id": 3735, "logprob": -1.9316406, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.48217773, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 3735, "logprob": -11.0078125, "text": "Test" },
        { "id": 2159, "logprob": -13.59375, "text": "request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.7089844, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.68847656, "special": false, "text": "\n" },
        { "id": 28771, "logprob": -1.9394531, "special": false, "text": "#" },
        { "id": 3735, "logprob": -2.8828125, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.37329102, "special": false, "text": " request" },
        { "id": 13, "logprob": -0.2602539, "special": false, "text": "\n" },
        { "id": 13, "logprob": -0.0017185211, "special": false, "text": "\n" },
        { "id": 1064, "logprob": -2.2753906, "special": false, "text": "##" },
        { "id": 3735, "logprob": -1.9316406, "special": false, "text": " Test" },
        { "id": 2159, "logprob": -0.48217773, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\n# Test request\n\n## Test request"
  }
]
@ -0,0 +1,109 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1724, "logprob": null, "text": "What" },
      { "id": 338, "logprob": -0.7133789, "text": "is" },
      { "id": 16030, "logprob": -13.9296875, "text": "gradient" },
      { "id": 26815, "logprob": -0.048919678, "text": "descent" },
      { "id": 29973, "logprob": -3.0078125, "text": "?" },
      { "id": 13, "logprob": -2.8105469, "text": "\n" },
      { "id": 13, "logprob": -0.84521484, "text": "\n" }
    ],
    "seed": null,
    "tokens": [
      { "id": 25584, "logprob": -0.017028809, "special": false, "text": "Grad" },
      { "id": 993, "logprob": -0.0027313232, "special": false, "text": "ient" },
      { "id": 26815, "logprob": -0.023254395, "special": false, "text": " descent" },
      { "id": 338, "logprob": -2.0623207e-05, "special": false, "text": " is" },
      { "id": 263, "logprob": -0.5361328, "special": false, "text": " a" },
      { "id": 937, "logprob": -0.17578125, "special": false, "text": " first" },
      { "id": 29899, "logprob": 0.0, "special": false, "text": "-" },
      { "id": 2098, "logprob": -0.00011539459, "special": false, "text": "order" },
      { "id": 13883, "logprob": -0.47436523, "special": false, "text": " optimization" },
      { "id": 5687, "logprob": -0.00027680397, "special": false, "text": " algorithm" }
    ],
    "top_tokens": null
  },
  "generated_text": "Gradient descent is a first-order optimization algorithm"
}
@ -0,0 +1,99 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 16030, "logprob": null, "text": "gradient" },
      { "id": 26815, "logprob": -6.4960938, "text": "descent" },
      { "id": 29973, "logprob": -5.1484375, "text": "?" },
      { "id": 13, "logprob": -4.0351562, "text": "\n" },
      { "id": 13, "logprob": -5.2265625, "text": "\n" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 10994, "logprob": -1.1542969, "special": false, "text": "Hello" },
      { "id": 29991, "logprob": 0.0, "special": false, "text": "!" },
      { "id": 739, "logprob": 0.0, "special": false, "text": " It" },
      { "id": 2444, "logprob": -0.42260742, "special": false, "text": " seems" },
      { "id": 366, "logprob": 0.0, "special": false, "text": " you" },
      { "id": 29915, "logprob": 0.0, "special": false, "text": "'" },
      { "id": 276, "logprob": -0.9838867, "special": false, "text": "re" },
      { "id": 3211, "logprob": 0.0, "special": false, "text": " address" },
      { "id": 292, "logprob": 0.0, "special": false, "text": "ing" },
      { "id": 263, "logprob": -0.15124512, "special": false, "text": " a" }
    ],
    "top_tokens": null
  },
  "generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
}
@ -0,0 +1,438 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1724, "logprob": null, "text": "What" },
        { "id": 338, "logprob": -0.7133789, "text": "is" },
        { "id": 16030, "logprob": -13.9296875, "text": "gradient" },
        { "id": 26815, "logprob": -0.048919678, "text": "descent" },
        { "id": 29973, "logprob": -3.0078125, "text": "?" },
        { "id": 13, "logprob": -2.8105469, "text": "\n" },
        { "id": 13, "logprob": -0.84521484, "text": "\n" }
      ],
      "seed": null,
      "tokens": [
        { "id": 25584, "logprob": -0.017028809, "special": false, "text": "Grad" },
        { "id": 993, "logprob": -0.0028476715, "special": false, "text": "ient" },
        { "id": 26815, "logprob": -0.023971558, "special": false, "text": " descent" },
        { "id": 338, "logprob": -2.0384789e-05, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.5229492, "special": false, "text": " a" },
        { "id": 937, "logprob": -0.17602539, "special": false, "text": " first" },
        { "id": 29899, "logprob": 0.0, "special": false, "text": "-" },
        { "id": 2098, "logprob": -0.000116467476, "special": false, "text": "order" },
        { "id": 13883, "logprob": -0.47436523, "special": false, "text": " optimization" },
        { "id": 5687, "logprob": -0.00027871132, "special": false, "text": " algorithm" }
      ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1724, "logprob": null, "text": "What" },
        { "id": 338, "logprob": -0.7128906, "text": "is" },
        { "id": 16030, "logprob": -13.9375, "text": "gradient" },
        { "id": 26815, "logprob": -0.05053711, "text": "descent" },
        { "id": 29973, "logprob": -3.0058594, "text": "?" },
        { "id": 13, "logprob": -2.8242188, "text": "\n" },
        { "id": 13, "logprob": -0.84521484, "text": "\n" }
      ],
      "seed": null,
      "tokens": [
        { "id": 25584, "logprob": -0.018859863, "special": false, "text": "Grad" },
        { "id": 993, "logprob": -0.002822876, "special": false, "text": "ient" },
        { "id": 26815, "logprob": -0.023254395, "special": false, "text": " descent" },
        { "id": 338, "logprob": -2.0384789e-05, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.5229492, "special": false, "text": " a" },
        { "id": 937, "logprob": -0.17126465, "special": false, "text": " first" },
        { "id": 29899, "logprob": 0.0, "special": false, "text": "-" },
        { "id": 2098, "logprob": -0.0001155138, "special": false, "text": "order" },
        { "id": 13883, "logprob": -0.47436523, "special": false, "text": " optimization" },
        { "id": 5687, "logprob": -0.00027036667, "special": false, "text": " algorithm" }
      ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1724, "logprob": null, "text": "What" },
        { "id": 338, "logprob": -0.71484375, "text": "is" },
        { "id": 16030, "logprob": -13.9375, "text": "gradient" },
        { "id": 26815, "logprob": -0.049346924, "text": "descent" },
        { "id": 29973, "logprob": -3.0078125, "text": "?" },
        { "id": 13, "logprob": -2.8242188, "text": "\n" },
        { "id": 13, "logprob": -0.86328125, "text": "\n" }
      ],
      "seed": null,
      "tokens": [
        { "id": 25584, "logprob": -0.017196655, "special": false, "text": "Grad" },
        { "id": 993, "logprob": -0.0028438568, "special": false, "text": "ient" },
        { "id": 26815, "logprob": -0.023254395, "special": false, "text": " descent" },
        { "id": 338, "logprob": -2.026558e-05, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.5229492, "special": false, "text": " a" },
        { "id": 937, "logprob": -0.17602539, "special": false, "text": " first" },
        { "id": 29899, "logprob": 0.0, "special": false, "text": "-" },
        { "id": 2098, "logprob": -0.00011622906, "special": false, "text": "order" },
        { "id": 13883, "logprob": -0.48608398, "special": false, "text": " optimization" },
        { "id": 5687, "logprob": -0.00027894974, "special": false, "text": " algorithm" }
      ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1724, "logprob": null, "text": "What" },
        { "id": 338, "logprob": -0.7192383, "text": "is" },
        { "id": 16030, "logprob": -13.9375, "text": "gradient" },
        { "id": 26815, "logprob": -0.050445557, "text": "descent" },
        { "id": 29973, "logprob": -3.0078125, "text": "?" },
        { "id": 13, "logprob": -2.8242188, "text": "\n" },
        { "id": 13, "logprob": -0.8276367, "text": "\n" }
      ],
      "seed": null,
      "tokens": [
        { "id": 25584, "logprob": -0.01727295, "special": false, "text": "Grad" },
        { "id": 993, "logprob": -0.0027542114, "special": false, "text": "ient" },
        { "id": 26815, "logprob": -0.023254395, "special": false, "text": " descent" },
        { "id": 338, "logprob": -2.0384789e-05, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.5229492, "special": false, "text": " a" },
        { "id": 937, "logprob": -0.17126465, "special": false, "text": " first" },
        { "id": 29899, "logprob": 0.0, "special": false, "text": "-" },
        { "id": 2098, "logprob": -0.00011301041, "special": false, "text": "order" },
        { "id": 13883, "logprob": -0.48608398, "special": false, "text": " optimization" },
        { "id": 5687, "logprob": -0.00027894974, "special": false, "text": " algorithm" }
      ],
      "top_tokens": null
    },
    "generated_text": "Gradient descent is a first-order optimization algorithm"
  }
]
@ -0,0 +1,106 @@
[
  {
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "logprobs": null,
        "message": { "content": "In a bustling city, a chicken named Cluck", "name": null, "role": "assistant", "tool_calls": null },
        "usage": null
      }
    ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  },
  {
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "logprobs": null,
        "message": { "content": "In a world where even chickens could dream big,", "name": null, "role": "assistant", "tool_calls": null },
        "usage": null
      }
    ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  },
  {
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "logprobs": null,
        "message": { "content": "In a world where even chickens could dream big,", "name": null, "role": "assistant", "tool_calls": null },
        "usage": null
      }
    ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  },
  {
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "logprobs": null,
        "message": { "content": "In a world where even chickens could dream big,", "name": null, "role": "assistant", "tool_calls": null },
        "usage": null
      }
    ],
    "created": 1727773835,
    "id": "",
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "object": "chat.completion",
    "system_fingerprint": "2.3.1-dev0-native",
    "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
  }
]
@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": { "content": "In a bustling city, a chicken named Cluck", "name": null, "role": "assistant", "tool_calls": null },
      "usage": null
    }
  ],
  "created": 1727556016,
  "id": "",
  "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.3.1-dev0-native",
  "usage": { "completion_tokens": 10, "prompt_tokens": 50, "total_tokens": 60 }
}
@ -3,9 +3,7 @@ import requests
import json
from aiohttp import ClientSession

from text_generation.types import (
from text_generation.types import Completion, ChatCompletionChunk
    Completion,
)


@pytest.fixture(scope="module")
@ -50,6 +48,114 @@ def test_flash_llama_completion_single_prompt(
    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
    flash_llama_completion, response_snapshot
):
    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
    request = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream_options": {"include_usage": True},
        "stream": True,
    }
    string = ""
    chunks = []
    had_usage = False
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # remove completion marking chunk
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(ChatCompletionChunk(**c))
                    assert "choices" in c
                    if len(c["choices"]) == 1:
                        index = c["choices"][0]["index"]
                        assert index == 0
                        string += c["choices"][0]["delta"]["content"]

                        has_usage = c["usage"] is not None
                        assert not had_usage
                        if has_usage:
                            had_usage = True
                    else:
                        raise RuntimeError("Expected different payload")
    assert had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )
    assert chunks == response_snapshot

    request = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream": True,
    }
    string = ""
    chunks = []
    had_usage = False
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # remove completion marking chunk
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(ChatCompletionChunk(**c))
                    assert "choices" in c
                    if len(c["choices"]) == 1:
                        index = c["choices"][0]["index"]
                        assert index == 0
                        string += c["choices"][0]["delta"]["content"]

                        has_usage = c["usage"] is not None
                        assert not had_usage
                        if has_usage:
                            had_usage = True
                    else:
                        raise RuntimeError("Expected different payload")
    assert not had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )


@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
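For context on the stream-usage behaviour exercised by the test above, here is a minimal standalone sketch, not part of the diff, of reading the final usage chunk from the same /v1/chat/completions endpoint with `requests`. The base URL is a placeholder, and the exact token counts depend on the deployed model; only the payload fields already used in the test are assumed.

# Hypothetical standalone sketch; assumes a running TGI server at http://localhost:8080.
import json
import requests

def last_usage(base_url="http://localhost:8080"):
    payload = {
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
        "max_tokens": 10,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    usage = None
    with requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=True) as r:
        for line in r.iter_lines():
            # SSE lines look like b"data: {...}"; skip keep-alives and the terminator.
            if not line or not line.startswith(b"data:"):
                continue
            data = line[len(b"data:"):].strip()
            if data == b"[DONE]":
                break
            chunk = json.loads(data)
            # The usage-bearing chunk arrives once, at the end of the stream.
            if chunk.get("usage") is not None:
                usage = chunk["usage"]
    return usage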
77  integration-tests/models/test_flash_llama_fp8_kv_cache.py  Normal file
@ -0,0 +1,77 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
    await flash_llama_fp8_kv_cache_handle.health(300)
    return flash_llama_fp8_kv_cache_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
    )

    assert (
        response.generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_all_params(
    flash_llama_fp8_kv_cache, response_snapshot
):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_load(
    flash_llama_fp8_kv_cache, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert (
        responses[0].generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"Different messages : {[r.generated_text for r in responses]}"
    assert responses == response_snapshot
File diff suppressed because one or more lines are too long
75  integration-tests/models/test_flash_mixtral.py  Normal file
@ -0,0 +1,75 @@
import pytest


@pytest.fixture(scope="module")
def flash_mixtral_handle(launcher):
    with launcher("mistralai/Mixtral-8x7B-v0.1", num_shard=8) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_mixtral(flash_mixtral_handle):
    await flash_mixtral_handle.health(300)
    return flash_mixtral_handle.client


@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral(flash_mixtral, response_snapshot):
    response = await flash_mixtral.generate(
        "What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "Gradient descent is an optimization algorithm used to minimize"
    )
    assert response == response_snapshot


@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral_all_params(flash_mixtral, response_snapshot):
    response = await flash_mixtral.generate(
        "What is gradient descent?\n\n",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is gradient descent?\n\nIt seems to me, that if you're"
    )
    assert response == response_snapshot


@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral_load(flash_mixtral, generate_load, response_snapshot):
    responses = await generate_load(
        flash_mixtral, "What is gradient descent?\n\n", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert responses[0].details.generated_tokens == 10
    assert (
        responses[0].generated_text
        == "Gradient descent is an optimization algorithm used to minimize"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot
60  integration-tests/models/test_flash_mixtral_gptq.py  Normal file
@ -0,0 +1,60 @@
import pytest


@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
    with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
    await flash_mixtral_gptq_handle.health(300)
    return flash_mixtral_gptq_handle.client


@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
    response = await flash_mixtral_gptq.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
    response = await flash_mixtral_gptq.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_mixtral_gptq_load(
    flash_mixtral_gptq, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot
75  integration-tests/models/test_flash_phi35_moe.py  Normal file
@ -0,0 +1,75 @@
import pytest


@pytest.fixture(scope="module")
def flash_phi35_moe_handle(launcher):
    with launcher(
        "microsoft/Phi-3.5-MoE-instruct",
        num_shard=4,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_phi35_moe(flash_phi35_moe_handle):
    await flash_phi35_moe_handle.health(300)
    return flash_phi35_moe_handle.client


@pytest.mark.asyncio
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
    response = await flash_phi35_moe.generate(
        "What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "Gradient descent is a first-order optimization algorithm"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
    response = await flash_phi35_moe.generate(
        "What is gradient descent?\n\n",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is gradient descent?\n\nHello! It seems you're addressing a"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
    responses = await generate_load(
        flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert responses[0].details.generated_tokens == 10
    assert (
        responses[0].generated_text
        == "Gradient descent is a first-order optimization algorithm"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot
105  integration-tests/models/test_mllama.py  Normal file
@ -0,0 +1,105 @@
import pytest
import base64
import asyncio


@pytest.fixture(scope="module")
def mllama_handle(launcher):
    with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def mllama(mllama_handle):
    await mllama_handle.health(300)
    return mllama_handle.client


# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
    # chicken = get_chicken()
    response = await mllama.chat(
        max_tokens=10,
        temperature=0.0,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Can you tell me a very short story based on the image?",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                        },
                    },
                ],
            },
        ],
    )

    assert response.usage == {
        "completion_tokens": 10,
        "prompt_tokens": 50,
        "total_tokens": 60,
    }
    assert (
        response.choices[0].message.content
        == "In a bustling city, a chicken named Cluck"
    )
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
    futures = [
        mllama.chat(
            max_tokens=10,
            temperature=0.0,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Can you tell me a very short story based on the image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                            },
                        },
                    ],
                },
            ],
        )
        for i in range(4)
    ]
    responses = await asyncio.gather(*futures)

    generated_texts = [response.choices[0].message.content for response in responses]

    assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
    assert len(generated_texts) == 4
    assert generated_texts, all(
        [text == generated_texts[0] for text in generated_texts]
    )

    assert responses == response_snapshot
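As a companion to the mllama test above, a sketch, not part of the diff, of the equivalent raw request against the OpenAI-compatible chat endpoint. The host and port are placeholders and the generated text will vary with the deployment; the payload fields mirror the ones the test already sends.

# Hypothetical sketch; assumes a TGI mllama server at http://localhost:8080.
import requests

payload = {
    "model": "tgi",
    "max_tokens": 10,
    "temperature": 0.0,
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Can you tell me a very short story based on the image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                    },
                },
            ],
        }
    ],
}
resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload).json()
print(resp["choices"][0]["message"]["content"], resp["usage"])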
@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_shard=2,
        disable_grammar_support=False,
    ) as handle:
        yield handle

@ -208,7 +210,7 @@ async def test_flash_llama_grammar_tools_stream(
    async for response in responses:
        count += 1

    assert count == 48
    assert count == 28
    assert response == response_snapshot

@ -12,11 +12,13 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
hf-hub = "0.3.2"
nix = { version = "0.28.0", features = ["signal"] }
once_cell = "1.19.0"
pyo3 = { workspace = true }
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
regex = "1.11.0"

[dev-dependencies]
float_eq = "1.0.1"
21  launcher/src/gpu.rs  Normal file
@@ -0,0 +1,21 @@
+pub fn get_cuda_capability() -> Option<(usize, usize)> {
+    use pyo3::prelude::*;
+
+    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
+        let torch = py.import_bound("torch.cuda")?;
+        let get_device_capability = torch.getattr("get_device_capability")?;
+        get_device_capability.call0()?.extract()
+    };
+
+    match pyo3::Python::with_gil(py_get_capability) {
+        Ok((major, minor)) if major < 0 || minor < 0 => {
+            tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
+            None
+        }
+        Ok((major, minor)) => Some((major as usize, minor as usize)),
+        Err(err) => {
+            tracing::warn!("Cannot determine GPU compute capability: {}", err);
+            None
+        }
+    }
+}
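The new helper only probes the GPU; the resolve_attention changes below consume its result. As a minimal illustrative sketch (not part of the commit, assuming the gpu module above is in scope), the reported capability drives the fallback choice like this:

// Illustrative sketch only: mirror of the fallback decision made in resolve_attention.
// Cards below compute capability 8.0 (pre-Ampere) fall back to "paged" attention,
// everything else to "flashdecoding".
fn fallback_attention() -> &'static str {
    match gpu::get_cuda_capability() {
        Some((major, _)) if major < 8 => "paged",
        _ => "flashdecoding",
    }
}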
@@ -5,6 +5,7 @@ use hf_hub::{
 };
 use nix::sys::signal::{self, Signal};
 use nix::unistd::Pid;
+use regex::Regex;
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -26,6 +27,7 @@ use thiserror::Error;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter};

 mod env_runtime;
+mod gpu;

 fn get_config(
     model_id: &str,
@@ -65,6 +67,7 @@ fn get_config(
 }

 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
+    let compute_capability = gpu::get_cuda_capability();
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
@@ -77,6 +80,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                 prefix_caching = Some("0".to_string());
             }
         }
+
+        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+            "paged"
+        } else {
+            "flashdecoding"
+        };
+
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
@@ -89,10 +99,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                 // flashinfer ?
                 if attention.is_none() {
                     tracing::info!(
-                        "Forcing flash decoding because model {} requires it",
+                        "Forcing attention to '{fallback_attention}' because model {} requires it",
                         config.model_type.as_ref().unwrap()
                     );
-                    attention = Some("flashdecoding".to_string());
+                    attention = Some(fallback_attention.to_string());
+                }
+                if fallback_attention == "paged" && prefix_caching.is_none() {
+                    tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
+                    prefix_caching = Some("0".to_string());
                 }
             }
             Some("t5") => {}
@@ -101,8 +115,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
             _ => {
                 if attention.is_none() {
-                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    attention = Some("flashdecoding".to_string());
+                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
+                    attention = Some(fallback_attention.to_string());
                 }
                 if prefix_caching.is_none() {
                     prefix_caching = Some("0".to_string());
@@ -110,8 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
         }
     }
-    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
     let attention = attention.unwrap_or("flashinfer".to_string());
+    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
+
     (prefix_caching, attention)
 }

@@ -285,6 +301,22 @@ impl std::fmt::Display for Dtype {
     }
 }

+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum KVCacheDtype {
+    #[clap(name = "fp8_e5m2")]
+    Fp8e5m2,
+}
+
+impl std::fmt::Display for KVCacheDtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            KVCacheDtype::Fp8e5m2 => {
+                write!(f, "fp8_e5m2")
+            }
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum RopeScaling {
     Linear,
@@ -367,7 +399,11 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,

-    /// Whether you want the model to be quantized.
+    /// Quantization method to use for the model. It is not necessary to specify this option
+    /// for pre-quantized models, since the quantization method is read from the model
+    /// configuration.
+    ///
+    /// Marlin kernels will be used automatically for GPTQ/AWQ models.
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,

@@ -382,6 +418,12 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,

+    /// Specify the dtype for the key-value cache. When this option is not provided,
+    /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
+    /// the only supported value is `fp8_e5m2` on CUDA.
+    #[clap(long, env, value_enum)]
+    kv_cache_dtype: Option<KVCacheDtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
@@ -650,6 +692,7 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
+    kv_cache_dtype: Option<KVCacheDtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -723,6 +766,11 @@ fn shard_manager(
         shard_args.push(dtype.to_string())
     }

+    if let Some(kv_cache_dtype) = kv_cache_dtype {
+        shard_args.push("--kv-cache-dtype".to_string());
+        shard_args.push(kv_cache_dtype.to_string())
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_args.push("--revision".to_string());
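Hedged usage sketch (not part of the commit): a self-contained mock of the KVCacheDtype / Display pairing above, showing that the string produced by Display is forwarded verbatim as the shard argument for `--kv-cache-dtype`.

// Illustrative only; names mirror the diff, but this block is a standalone sketch.
use std::fmt;

#[derive(Clone, Copy, Debug)]
enum KVCacheDtype {
    Fp8e5m2,
}

impl fmt::Display for KVCacheDtype {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            KVCacheDtype::Fp8e5m2 => write!(f, "fp8_e5m2"),
        }
    }
}

fn main() {
    let mut shard_args: Vec<String> = Vec::new();
    // Same shape as the launcher code above: only push the flag when the option is set.
    if let Some(kv_cache_dtype) = Some(KVCacheDtype::Fp8e5m2) {
        shard_args.push("--kv-cache-dtype".to_string());
        shard_args.push(kv_cache_dtype.to_string());
    }
    assert_eq!(shard_args, ["--kv-cache-dtype", "fp8_e5m2"]);
}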
@@ -1034,6 +1082,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
             Ok(log) => log.trace(),
             // For interactive debugging ?
             Err(_) => {
+                if LevelFilter::current() >= tracing::Level::DEBUG {
                 stdout.write_all(line).unwrap();
                 if lines.peek().is_some() {
                     stdout.write_all(b"\n").unwrap();
@@ -1045,6 +1094,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
             }
         }
     }
+    }
 }

 fn find_num_shards(
@@ -1277,6 +1327,7 @@ fn spawn_shards(
     let otlp_service_name = args.otlp_service_name.clone();
     let speculate = args.speculate;
     let dtype = args.dtype;
+    let kv_cache_dtype = args.kv_cache_dtype;
     let trust_remote_code = args.trust_remote_code;
     let master_port = args.master_port;
     let disable_custom_kernels = args.disable_custom_kernels;
@@ -1295,6 +1346,7 @@ fn spawn_shards(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
             uds_path,
             rank,
@@ -1787,14 +1839,37 @@ fn main() -> Result<(), LauncherError> {
             if adapter.contains('=') {
                 continue;
             }

+            let adapter = adapter.trim();
+
+            // check if adapter has more than 1 '@'
+            if adapter.matches('@').count() > 1 {
+                return Err(LauncherError::ArgumentValidation(format!(
+                    "Invalid LoRA adapter format: {}",
+                    adapter
+                )));
+            }
+
+            // capture adapter_id, path, revision in format of adapter_id=path@revision
+            let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
+            if let Some(caps) = re.captures(adapter) {
+                let adapter_id = caps.get(1).map_or("", |m| m.as_str());
+                let revision = caps.get(3).map(|m| m.as_str());
+
                 download_convert_model(
-                    adapter,
-                    None,
+                    adapter_id,
+                    revision,
                     args.trust_remote_code,
                     args.huggingface_hub_cache.as_deref(),
                     args.weights_cache_override.as_deref(),
                     running.clone(),
                 )?;
+            } else {
+                return Err(LauncherError::ArgumentValidation(format!(
+                    "Invalid LoRA adapter format: {}",
+                    adapter
+                )));
+            }
         }
     }

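Illustrative sketch (not part of the commit; the adapter strings below are made up): what the capture groups of the new adapter regex yield for the `id`, `id@revision`, and `id=path@revision` forms named in the diff's own comment.

// Standalone demo of the regex used above; group 1 = adapter id, group 2 = path, group 3 = revision.
use regex::Regex;

fn main() {
    let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
    for adapter in ["org/lora-a", "org/lora-a@main", "my-lora=/data/lora-a@v1"] {
        let caps = re.captures(adapter).unwrap();
        println!(
            "id={:?} path={:?} revision={:?}",
            caps.get(1).map(|m| m.as_str()),
            caps.get(2).map(|m| m.as_str()),
            caps.get(3).map(|m| m.as_str()),
        );
    }
}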
21  nix/client.nix  Normal file
@@ -0,0 +1,21 @@
+{
+  buildPythonPackage,
+  poetry-core,
+  huggingface-hub,
+  pydantic,
+}:
+
+buildPythonPackage {
+  name = "text-generation";
+
+  src = ../clients/python;
+
+  pyproject = true;
+
+  build-system = [ poetry-core ];
+
+  dependencies = [
+    huggingface-hub
+    pydantic
+  ];
+}
23  nix/docker.nix  Normal file
@@ -0,0 +1,23 @@
+{
+  dockerTools,
+  cacert,
+  text-generation-inference,
+  stream ? false,
+}:
+
+let
+  build = if stream then dockerTools.streamLayeredImage else dockerTools.buildLayeredImage;
+in
+build {
+  name = "tgi-docker";
+  tag = "latest";
+  config = {
+    EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
+    Env = [
+      "HF_HOME=/data"
+      "PORT=80"
+    ];
+
+  };
+  contents = [ cacert ];
+}
54  nix/impure-shell.nix  Normal file
@@ -0,0 +1,54 @@
+{
+  mkShell,
+  openssl,
+  pkg-config,
+  protobuf,
+  python3,
+  pyright,
+  redocly,
+  ruff,
+  rust-bin,
+  server,
+}:
+
+mkShell {
+  buildInputs =
+    [
+      openssl.dev
+      pkg-config
+      (rust-bin.stable.latest.default.override {
+        extensions = [
+          "rust-analyzer"
+          "rust-src"
+        ];
+      })
+      protobuf
+      pyright
+      redocly
+      ruff
+    ]
+    ++ (with python3.pkgs; [
+      venvShellHook
+      docker
+      pip
+      ipdb
+      click
+      pytest
+      pytest-asyncio
+      syrupy
+    ]);
+
+  inputsFrom = [ server ];
+
+  venvDir = "./.venv";
+
+  postVenvCreation = ''
+    unset SOURCE_DATE_EPOCH
+    ( cd server ; python -m pip install --no-dependencies -e . )
+    ( cd clients/python ; python -m pip install --no-dependencies -e . )
+  '';
+  postShellHook = ''
+    unset SOURCE_DATE_EPOCH
+    export PATH=$PATH:~/.cargo/bin
+  '';
+}
41  nix/overlay.nix  Normal file
@@ -0,0 +1,41 @@
+final: prev: {
+  # You can use this overlay to temporarily override packages for
+  # development. For permanent overrides, it's better to do this in
+  # our package flake:
+  #
+  # https://github.com/huggingface/text-generation-inference-nix
+  #
+  # Note that overriding packages that are in the transitive closure
+  # of many other packages (e.g. transformers) will require a large
+  # rebuild.
+
+  pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [
+    (
+      python-self: python-super: with python-self; {
+        # Python package override example:
+        # transformers = python-super.transformers.overrideAttrs (
+        #   _: _: {
+        #     src = final.fetchFromGitHub {
+        #       owner = "huggingface";
+        #       repo = "transformers";
+        #       rev = "2bd4d5897dc73e8b172832070a6f9e567a0df017";
+        #       hash = "sha256-JOIpKH9ssDEfI2Tf15e0iPKtThJwQ9GxMvRAnm+M2Pg=";
+        #     };
+        #   }
+        # );
+      }
+    )
+  ];
+
+  # Non-python package override example:
+  #
+  # ripgrep = prev.ripgrep.overrideAttrs (
+  #   _: _: {
+  #     src = final.fetchFromGitHub {
+  #       owner = "BurntSushi";
+  #       repo = "ripgrep";
+  #       rev = "79cbe89deb1151e703f4d91b19af9cdcc128b765";
+  #       hash = "sha256-JPTM2KNmGMb+/jOfK3X7OM1wnN+3TU35SJOIcqmp3mg=";
+  #     };
+  # });
+}
@@ -13,6 +13,7 @@
   flash-attn,
   flash-attn-layer-norm,
   flash-attn-rotary,
+  flash-attn-v1,
   grpc-interceptor,
   grpcio-reflection,
   grpcio-status,
@@ -21,6 +22,7 @@
   loguru,
   mamba-ssm,
   marlin-kernels,
+  moe-kernels,
   opentelemetry-api,
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,
@@ -88,6 +90,7 @@ buildPythonPackage {
     loguru
     mamba-ssm
     marlin-kernels
+    moe-kernels
     opentelemetry-api
     opentelemetry-exporter-otlp
     opentelemetry-instrumentation-grpc
@@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
-pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
+pyo3 = { workspace = true }


 [build-dependencies]
@@ -146,6 +146,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
+    Mllama,
     Idefics2(Idefics2),
     Ssm,
     GptBigcode,
@@ -159,6 +160,7 @@ pub enum Config {
     #[serde(rename = "phi-msft")]
     PhiMsft,
     Phi3,
+    PhiMoe,
     Llama,
     Baichuan,
     Paligemma(Paligemma),
@@ -29,7 +29,7 @@ impl ChatTemplate {
         env.set_unknown_method_callback(pycompat::unknown_method_callback);
         let template_str = template.into_boxed_str();
         env.add_function("raise_exception", raise_exception);
-        tracing::debug!("Loading template: {:#?}", template_str);
+        tracing::debug!("Loading template: {}", template_str);

         // leaking env and template_str as read-only, static resources for performance.
         let template = Box::leak(env)
@@ -8,9 +8,11 @@ use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
 };
+use async_stream::stream;
 use async_trait::async_trait;
 use chat_template::ChatTemplate;
 use futures::future::try_join_all;
+use futures::Stream;
 use minijinja::ErrorKind;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -87,7 +89,14 @@ impl Infer {
     pub(crate) async fn generate_stream<'a>(
         &'a self,
         request: GenerateRequest,
-    ) -> Result<GenerateStreamResponse, InferError> {
+    ) -> Result<
+        (
+            OwnedSemaphorePermit,
+            u32, // input_length
+            impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
+        ),
+        InferError,
+    > {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -107,9 +116,18 @@ impl Infer {
         })?;

         let input_length = valid_request.input_length;
-        let generation_stream = self.backend.schedule(valid_request)?;
+        let mut generation_stream = self.backend.schedule(valid_request)?;

-        Ok((permit, input_length, generation_stream))
+        // Wrap generation stream to update the backend health if the stream contains an error
+        let final_stream = stream! {
+            while let Some(response) = generation_stream.next().await {
+                yield response.inspect_err(|_err| {
+                    self.backend_health.store(false, Ordering::SeqCst);
+                })
+            }
+        };
+
+        Ok((permit, input_length, final_stream))
     }

     /// Tokenizer the input
@@ -278,13 +296,6 @@ impl Infer {
         }
     }

-/// Type alias for generation responses
-pub(crate) type GenerateStreamResponse = (
-    OwnedSemaphorePermit,
-    u32, // input_length
-    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-);
-
 #[derive(Debug)]
 pub struct GeneratedText {
     pub text: String,
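A minimal, self-contained sketch of the wrapping pattern introduced above (an assumed helper, not the crate's API): forward every item of an inner stream unchanged, but flip a shared health flag the moment an error comes through.

// Illustrative only; generic over item and error types.
use async_stream::stream;
use futures::{Stream, StreamExt};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

fn watch_health<S, T, E>(
    mut inner: S,
    health: Arc<AtomicBool>,
) -> impl Stream<Item = Result<T, E>>
where
    S: Stream<Item = Result<T, E>> + Unpin,
{
    stream! {
        while let Some(item) = inner.next().await {
            // Any error marks the backend unhealthy before being forwarded to the caller.
            yield item.inspect_err(|_| health.store(false, Ordering::SeqCst));
        }
    }
}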
@@ -1,4 +0,0 @@
-mod queue;
-mod scheduler;
-
-pub(crate) use scheduler::BackendV2;

(File diff suppressed because it is too large.)
@@ -9,7 +9,10 @@ mod kserve;
 pub mod logging;

 pub mod usage_stats;
+mod vertex;

+use crate::infer::{Infer, InferError};
+use crate::server::prepare_chat_input;
 use serde::{Deserialize, Serialize};
 use tracing::warn;
 use utoipa::ToSchema;
@@ -54,32 +57,6 @@ impl std::str::FromStr for Attention {
     }
 }

-#[derive(Clone, Deserialize, ToSchema)]
-pub(crate) struct GenerateVertexInstance {
-    #[schema(example = "What is Deep Learning?")]
-    pub inputs: String,
-    #[schema(nullable = true, default = "null", example = "null")]
-    pub parameters: Option<GenerateParameters>,
-}
-
-#[derive(Clone, Deserialize, ToSchema)]
-#[serde(untagged)]
-enum VertexInstance {
-    Generate(GenerateVertexInstance),
-    Chat(ChatRequest),
-}
-
-#[derive(Deserialize, ToSchema)]
-pub(crate) struct VertexRequest {
-    #[serde(rename = "instances")]
-    pub instances: Vec<VertexInstance>,
-}
-
-#[derive(Clone, Deserialize, ToSchema, Serialize)]
-pub(crate) struct VertexResponse {
-    pub predictions: Vec<String>,
-}
-
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -174,6 +151,7 @@ impl HubProcessorConfig {
 }

 #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
+#[cfg_attr(test, derive(PartialEq))]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
     /// A string that represents a [JSON Schema](https://json-schema.org/).
@@ -230,6 +208,7 @@ pub struct Info {
 }

 #[derive(Clone, Debug, Deserialize, ToSchema, Default)]
+#[cfg_attr(test, derive(PartialEq))]
 pub(crate) struct GenerateParameters {
     /// Generate best_of sequences and return the one if the highest token logprobs.
     #[serde(default)]
@@ -684,6 +663,7 @@ pub(crate) struct ChatCompletionChunk {
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionChoice>,
+    pub usage: Option<Usage>,
 }

 #[derive(Clone, Serialize, ToSchema)]
@@ -732,6 +712,7 @@ impl ChatCompletionChunk {
         created: u64,
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
+        usage: Option<Usage>,
     ) -> Self {
         let delta = match (delta, tool_calls) {
             (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
@@ -766,11 +747,13 @@ impl ChatCompletionChunk {
                 logprobs,
                 finish_reason,
             }],
+            usage,
         }
     }
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize)]
+#[cfg_attr(test, derive(Debug, PartialEq, Default))]
 pub(crate) struct ChatRequest {
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
@@ -880,6 +863,93 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub guideline: Option<String>,
+
+    /// Options for streaming response. Only set this when you set stream: true.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stream_options: Option<StreamOptions>,
+}
+
+impl ChatRequest {
+    fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
+        let ChatRequest {
+            model,
+            max_tokens,
+            messages,
+            seed,
+            stop,
+            stream,
+            tools,
+            tool_choice,
+            tool_prompt,
+            temperature,
+            response_format,
+            guideline,
+            presence_penalty,
+            frequency_penalty,
+            top_p,
+            top_logprobs,
+            ..
+        } = self;
+
+        let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+        let max_new_tokens = max_tokens.or(Some(100));
+        let tool_prompt = tool_prompt
+            .filter(|s| !s.is_empty())
+            .unwrap_or_else(default_tool_prompt);
+        let stop = stop.unwrap_or_default();
+        // enable greedy only when temperature is 0
+        let (do_sample, temperature) = match temperature {
+            Some(temperature) if temperature == 0.0 => (false, None),
+            other => (true, other),
+        };
+        let (inputs, grammar, using_tools) = prepare_chat_input(
+            infer,
+            response_format,
+            tools,
+            tool_choice,
+            &tool_prompt,
+            guideline,
+            messages,
+        )?;
+
+        Ok((
+            GenerateRequest {
+                inputs: inputs.to_string(),
+                add_special_tokens: false,
+                parameters: GenerateParameters {
+                    best_of: None,
+                    temperature,
+                    repetition_penalty,
+                    frequency_penalty,
+                    top_k: None,
+                    top_p,
+                    typical_p: None,
+                    do_sample,
+                    max_new_tokens,
+                    return_full_text: None,
+                    stop,
+                    truncate: None,
+                    watermark: false,
+                    details: true,
+                    decoder_input_details: !stream,
+                    seed,
+                    top_n_tokens: top_logprobs,
+                    grammar,
+                    adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                },
+            },
+            using_tools,
+        ))
+    }
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+#[cfg_attr(test, derive(Debug, PartialEq))]
+struct StreamOptions {
+    /// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
+    #[schema(example = "true")]
+    include_usage: bool,
 }

 pub fn default_tool_prompt() -> String {
@@ -969,6 +1039,7 @@ pub(crate) struct FunctionDefinition {
 }

 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
+#[cfg_attr(test, derive(PartialEq))]
 pub(crate) struct Tool {
     // The type of the tool. Currently, only 'function' is supported.
     #[schema(example = "function")]
@@ -1472,6 +1543,27 @@ mod tests {
         let textmsg: TextMessage = message.into();
         assert_eq!(textmsg.content, "Whats in this image?");
     }
+
+    #[test]
+    fn test_chat_stream_options() {
+        let json = json!({
+            "model": "",
+            "stream_options": {"include_usage": true},
+            "messages": [{
+                "role": "user",
+                "content": "Hello"
+            }]
+        });
+        let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
+
+        assert!(matches!(
+            request.stream_options,
+            Some(StreamOptions {
+                include_usage: true
+            })
+        ));
+    }
+
     #[test]
     fn openai_output() {
         let message = OutputMessage::ChatMessage(TextMessage {
@@ -8,20 +8,20 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
+use crate::vertex::vertex_compatibility;
+use crate::ChatTokenizeResponse;
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
     HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
-    OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token,
-    TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
+    OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
+    TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
-    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
-    VertexResponse,
+    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
 };
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
 use crate::{ModelInfo, ModelsInfo};
|
|||||||
)]
|
)]
|
||||||
async fn get_chat_tokenize(
|
async fn get_chat_tokenize(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
Json(req): Json<ChatRequest>,
|
Json(chat): Json<ChatRequest>,
|
||||||
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
metrics::counter!("tgi_request_count").increment(1);
|
metrics::counter!("tgi_request_count").increment(1);
|
||||||
|
|
||||||
let ChatRequest {
|
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
|
||||||
model,
|
|
||||||
max_tokens,
|
|
||||||
messages,
|
|
||||||
seed,
|
|
||||||
stop,
|
|
||||||
stream,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
tool_prompt,
|
|
||||||
temperature,
|
|
||||||
response_format,
|
|
||||||
guideline,
|
|
||||||
..
|
|
||||||
} = req;
|
|
||||||
|
|
||||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
|
||||||
let (inputs, _grammar, _using_tools) = prepare_chat_input(
|
|
||||||
&infer,
|
|
||||||
response_format,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
&tool_prompt,
|
|
||||||
guideline,
|
|
||||||
messages,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let generate_request = GenerateRequest {
|
|
||||||
inputs,
|
|
||||||
add_special_tokens: false,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
best_of: None,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
top_k: None,
|
|
||||||
top_p: None,
|
|
||||||
typical_p: None,
|
|
||||||
do_sample: true,
|
|
||||||
max_new_tokens: max_tokens,
|
|
||||||
return_full_text: None,
|
|
||||||
stop: stop.unwrap_or_default(),
|
|
||||||
truncate: None,
|
|
||||||
watermark: false,
|
|
||||||
details: false,
|
|
||||||
decoder_input_details: !stream,
|
|
||||||
seed,
|
|
||||||
top_n_tokens: None,
|
|
||||||
grammar: _grammar,
|
|
||||||
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let input = generate_request.inputs.clone();
|
let input = generate_request.inputs.clone();
|
||||||
let encoding = infer.tokenize(generate_request).await?;
|
let encoding = infer.tokenize(generate_request).await?;
|
||||||
if let Some(encoding) = encoding {
|
if let Some(encoding) = encoding {
|
||||||
@ -1162,76 +1110,20 @@ async fn chat_completions(
|
|||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
Extension(compute_type): Extension<ComputeType>,
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
Json(req): Json<ChatRequest>,
|
Json(chat): Json<ChatRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
metrics::counter!("tgi_request_count").increment(1);
|
metrics::counter!("tgi_request_count").increment(1);
|
||||||
let ChatRequest {
|
let ChatRequest {
|
||||||
model,
|
|
||||||
logprobs,
|
|
||||||
max_tokens,
|
|
||||||
messages,
|
|
||||||
presence_penalty,
|
|
||||||
seed,
|
|
||||||
stop,
|
|
||||||
stream,
|
stream,
|
||||||
tools,
|
stream_options,
|
||||||
tool_choice,
|
logprobs,
|
||||||
tool_prompt,
|
|
||||||
temperature,
|
|
||||||
response_format,
|
|
||||||
guideline,
|
|
||||||
..
|
..
|
||||||
} = req;
|
} = chat.clone();
|
||||||
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
|
chat.try_into_generate(&infer)?;
|
||||||
|
|
||||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
let logprobs = logprobs.unwrap_or_default();
|
||||||
let max_new_tokens = max_tokens.or(Some(100));
|
|
||||||
let logprobs = logprobs.unwrap_or(false);
|
|
||||||
let tool_prompt = tool_prompt
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
.unwrap_or_else(default_tool_prompt);
|
|
||||||
let stop = stop.unwrap_or_default();
|
|
||||||
// enable greedy only when temperature is 0
|
|
||||||
let (do_sample, temperature) = match temperature {
|
|
||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
|
||||||
other => (true, other),
|
|
||||||
};
|
|
||||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
|
||||||
&infer,
|
|
||||||
response_format,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
&tool_prompt,
|
|
||||||
guideline,
|
|
||||||
messages,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// build the request passing some parameters
|
|
||||||
let generate_request = GenerateRequest {
|
|
||||||
inputs: inputs.to_string(),
|
|
||||||
add_special_tokens: false,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
best_of: None,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
frequency_penalty: req.frequency_penalty,
|
|
||||||
top_k: None,
|
|
||||||
top_p: req.top_p,
|
|
||||||
typical_p: None,
|
|
||||||
do_sample,
|
|
||||||
max_new_tokens,
|
|
||||||
return_full_text: None,
|
|
||||||
stop,
|
|
||||||
truncate: None,
|
|
||||||
watermark: false,
|
|
||||||
details: true,
|
|
||||||
decoder_input_details: !stream,
|
|
||||||
seed,
|
|
||||||
top_n_tokens: req.top_logprobs,
|
|
||||||
grammar,
|
|
||||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
// static values that will be returned in all cases
|
// static values that will be returned in all cases
|
||||||
let model_id = info.model_id.clone();
|
let model_id = info.model_id.clone();
|
||||||
@@ -1265,6 +1157,28 @@ async fn chat_completions(
                         (content, None)
                     };

+                    let (usage, finish_reason) = match stream_token.details {
+                        Some(details) => {
+                            let usage = if stream_options
+                                .as_ref()
+                                .map(|s| s.include_usage)
+                                .unwrap_or(false)
+                            {
+                                let completion_tokens = details.generated_tokens;
+                                let prompt_tokens = details.input_length;
+                                let total_tokens = prompt_tokens + completion_tokens;
+                                Some(Usage {
+                                    completion_tokens,
+                                    prompt_tokens,
+                                    total_tokens,
+                                })
+                            } else {
+                                None
+                            };
+                            (usage, Some(details.finish_reason.format(true)))
+                        }
+                        None => (None, None),
+                    };
                     event
                         .json_data(CompletionType::ChatCompletionChunk(
                             ChatCompletionChunk::new(
@@ -1274,7 +1188,8 @@ async fn chat_completions(
                                 tool_calls,
                                 current_time,
                                 logprobs,
-                                stream_token.details.map(|d| d.finish_reason.format(true)),
+                                finish_reason,
+                                usage,
                             ),
                         ))
                         .unwrap_or_else(|e| {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate tokens from Vertex request
|
|
||||||
#[utoipa::path(
|
|
||||||
post,
|
|
||||||
tag = "Text Generation Inference",
|
|
||||||
path = "/vertex",
|
|
||||||
request_body = VertexRequest,
|
|
||||||
responses(
|
|
||||||
(status = 200, description = "Generated Text", body = VertexResponse),
|
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Request failed during generation"})),
|
|
||||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Model is overloaded"})),
|
|
||||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Input validation error"})),
|
|
||||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Incomplete generation"})),
|
|
||||||
)
|
|
||||||
)]
|
|
||||||
#[instrument(
|
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
total_time,
|
|
||||||
validation_time,
|
|
||||||
queue_time,
|
|
||||||
inference_time,
|
|
||||||
time_per_token,
|
|
||||||
seed,
|
|
||||||
)
|
|
||||||
)]
|
|
||||||
async fn vertex_compatibility(
|
|
||||||
Extension(infer): Extension<Infer>,
|
|
||||||
Extension(compute_type): Extension<ComputeType>,
|
|
||||||
Json(req): Json<VertexRequest>,
|
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
|
||||||
let span = tracing::Span::current();
|
|
||||||
metrics::counter!("tgi_request_count").increment(1);
|
|
||||||
|
|
||||||
// check that theres at least one instance
|
|
||||||
if req.instances.is_empty() {
|
|
||||||
return Err((
|
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Input validation error".to_string(),
|
|
||||||
error_type: "Input validation error".to_string(),
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare futures for all instances
|
|
||||||
let mut futures = Vec::with_capacity(req.instances.len());
|
|
||||||
|
|
||||||
for instance in req.instances.iter() {
|
|
||||||
let generate_request = match instance {
|
|
||||||
VertexInstance::Generate(instance) => GenerateRequest {
|
|
||||||
inputs: instance.inputs.clone(),
|
|
||||||
add_special_tokens: true,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
do_sample: true,
|
|
||||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
|
||||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
|
||||||
details: true,
|
|
||||||
decoder_input_details: true,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
VertexInstance::Chat(instance) => {
|
|
||||||
let ChatRequest {
|
|
||||||
model,
|
|
||||||
max_tokens,
|
|
||||||
messages,
|
|
||||||
seed,
|
|
||||||
stop,
|
|
||||||
stream,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
tool_prompt,
|
|
||||||
temperature,
|
|
||||||
response_format,
|
|
||||||
guideline,
|
|
||||||
presence_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
top_p,
|
|
||||||
top_logprobs,
|
|
||||||
..
|
|
||||||
} = instance.clone();
|
|
||||||
|
|
||||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
|
||||||
let max_new_tokens = max_tokens.or(Some(100));
|
|
||||||
let tool_prompt = tool_prompt
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
.unwrap_or_else(default_tool_prompt);
|
|
||||||
let stop = stop.unwrap_or_default();
|
|
||||||
// enable greedy only when temperature is 0
|
|
||||||
let (do_sample, temperature) = match temperature {
|
|
||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
|
||||||
other => (true, other),
|
|
||||||
};
|
|
||||||
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
|
||||||
&infer,
|
|
||||||
response_format,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
&tool_prompt,
|
|
||||||
guideline,
|
|
||||||
messages,
|
|
||||||
) {
|
|
||||||
Ok(result) => result,
|
|
||||||
Err(e) => {
|
|
||||||
return Err((
|
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: format!("Failed to prepare chat input: {}", e),
|
|
||||||
error_type: "Input preparation error".to_string(),
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
GenerateRequest {
|
|
||||||
inputs: inputs.to_string(),
|
|
||||||
add_special_tokens: false,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
best_of: None,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
top_k: None,
|
|
||||||
top_p,
|
|
||||||
typical_p: None,
|
|
||||||
do_sample,
|
|
||||||
max_new_tokens,
|
|
||||||
return_full_text: None,
|
|
||||||
stop,
|
|
||||||
truncate: None,
|
|
||||||
watermark: false,
|
|
||||||
details: true,
|
|
||||||
decoder_input_details: !stream,
|
|
||||||
seed,
|
|
||||||
top_n_tokens: top_logprobs,
|
|
||||||
grammar,
|
|
||||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let infer_clone = infer.clone();
|
|
||||||
let compute_type_clone = compute_type.clone();
|
|
||||||
let span_clone = span.clone();
|
|
||||||
|
|
||||||
futures.push(async move {
|
|
||||||
generate_internal(
|
|
||||||
Extension(infer_clone),
|
|
||||||
compute_type_clone,
|
|
||||||
Json(generate_request),
|
|
||||||
span_clone,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map(|(_, Json(generation))| generation.generated_text)
|
|
||||||
.map_err(|_| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Incomplete generation".into(),
|
|
||||||
error_type: "Incomplete generation".into(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
|
||||||
let results = futures::future::join_all(futures).await;
|
|
||||||
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
|
||||||
let predictions = predictions?;
|
|
||||||
|
|
||||||
let response = VertexResponse { predictions };
|
|
||||||
Ok((HeaderMap::new(), Json(response)).into_response())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tokenize inputs
|
/// Tokenize inputs
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
@@ -1664,6 +1399,7 @@ StreamDetails,
 ErrorResponse,
 GrammarType,
 Usage,
+StreamOptions,
 DeltaToolCall,
 ToolType,
 Tool,
@@ -2136,9 +1872,12 @@ async fn start(
         .unwrap();
     // .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
     // .unwrap();
-    let prom_handle = builder
-        .install_recorder()
-        .expect("failed to install metrics recorder");
+    // See: https://github.com/metrics-rs/metrics/issues/467#issuecomment-2022755151
+    let (recorder, _) = builder
+        .build()
+        .expect("failed to build prometheus recorder");
+    let prom_handle = recorder.handle();
+    metrics::set_global_recorder(recorder).expect("Failed to set global recorder");

     // Metrics descriptions
     metrics::describe_counter!("tgi_request_success", "Number of successful requests");
@@ -2198,6 +1937,11 @@ async fn start(
         metrics::Unit::Count,
         "Maximum tokens for the current batch"
     );
+    metrics::describe_gauge!(
+        "tgi_batch_total_tokens",
+        metrics::Unit::Count,
+        "Maximum amount of tokens in total."
+    );
     metrics::describe_histogram!(
         "tgi_request_max_new_tokens",
         metrics::Unit::Count,
@@ -2290,7 +2034,8 @@ async fn start(

     #[cfg(feature = "google")]
     {
-        use crate::VertexInstance;
+        use crate::vertex::__path_vertex_compatibility;
+        use crate::vertex::{VertexInstance, VertexRequest, VertexResponse};

         #[derive(OpenApi)]
         #[openapi(
@@ -2609,7 +2354,7 @@ pub enum WebServerError {

 type PreparedInput = (String, Option<GrammarType>, bool);

-fn prepare_chat_input(
+pub(crate) fn prepare_chat_input(
     infer: &Infer,
     response_format: Option<GrammarType>,
     tools: Option<Vec<Tool>>,
@@ -567,6 +567,7 @@ fn image_tokens(
     use HubPreprocessorConfig::*;
     match config {
         Idefics => "<image>".to_string(),
+        Mllama => "<|image|>".to_string(),
         Idefics2(config) => {
             const FAKE: &str = "<fake_token_around_image>";
             const IMAGE: &str = "<image>";
@@ -618,7 +619,7 @@ fn prepare_input(
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
360
router/src/vertex.rs
Normal file
360
router/src/vertex.rs
Normal file
@ -0,0 +1,360 @@
|
|||||||
|
use crate::infer::Infer;
|
||||||
|
use crate::server::{generate_internal, ComputeType};
|
||||||
|
use crate::{
|
||||||
|
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
|
||||||
|
StreamOptions, Tool, ToolChoice,
|
||||||
|
};
|
||||||
|
use axum::extract::Extension;
|
||||||
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::Json;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::instrument;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
|
pub(crate) struct GenerateVertexInstance {
|
||||||
|
#[schema(example = "What is Deep Learning?")]
|
||||||
|
pub inputs: String,
|
||||||
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
|
pub parameters: Option<GenerateParameters>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
|
pub(crate) struct VertexChat {
|
||||||
|
messages: Vec<Message>,
|
||||||
|
// Messages is ignored there.
|
||||||
|
#[serde(default)]
|
||||||
|
parameters: VertexParameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexParameters {
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
    pub model: Option<String>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(default)]
    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,

    /// UNUSED
    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
    #[serde(default)]
    pub logit_bias: Option<Vec<f32>>,

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
    /// output token returned in the content of message.
    #[serde(default)]
    #[schema(example = "false")]
    pub logprobs: Option<bool>,

    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
    /// an associated log probability. logprobs must be set to true if this parameter is used.
    #[serde(default)]
    #[schema(example = "5")]
    pub top_logprobs: Option<u32>,

    /// The maximum number of tokens that can be generated in the chat completion.
    #[serde(default)]
    #[schema(example = "32")]
    pub max_tokens: Option<u32>,

    /// UNUSED
    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
    #[serde(default)]
    #[schema(nullable = true, example = "2")]
    pub n: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics
    #[serde(default)]
    #[schema(nullable = true, example = 0.1)]
    pub presence_penalty: Option<f32>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stop: Option<Vec<String>>,

    #[serde(default = "bool::default")]
    pub stream: bool,

    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
    /// lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(default)]
    #[schema(nullable = true, example = 1.0)]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    #[serde(default)]
    #[schema(nullable = true, example = 0.95)]
    pub top_p: Option<f32>,

    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
    /// functions the model may generate JSON inputs for.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tools: Option<Vec<Tool>>,

    /// A prompt to be appended before the tools
    #[serde(default)]
    #[schema(
        nullable = true,
        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
    )]
    pub tool_prompt: Option<String>,

    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tool_choice: ToolChoice,

    /// Response format constraints for the generation.
    ///
    /// NOTE: A request can use `response_format` OR `tools` but not both.
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub response_format: Option<GrammarType>,

    /// A guideline to be used in the chat_template
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub guideline: Option<String>,

    /// Options for streaming response. Only set this when you set stream: true.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stream_options: Option<StreamOptions>,
}

impl From<VertexChat> for ChatRequest {
    fn from(val: VertexChat) -> Self {
        Self {
            messages: val.messages,
            frequency_penalty: val.parameters.frequency_penalty,
            guideline: val.parameters.guideline,
            logit_bias: val.parameters.logit_bias,
            logprobs: val.parameters.logprobs,
            max_tokens: val.parameters.max_tokens,
            model: val.parameters.model,
            n: val.parameters.n,
            presence_penalty: val.parameters.presence_penalty,
            response_format: val.parameters.response_format,
            seed: val.parameters.seed,
            stop: val.parameters.stop,
            stream_options: val.parameters.stream_options,
            stream: val.parameters.stream,
            temperature: val.parameters.temperature,
            tool_choice: val.parameters.tool_choice,
            tool_prompt: val.parameters.tool_prompt,
            tools: val.parameters.tools,
            top_logprobs: val.parameters.top_logprobs,
            top_p: val.parameters.top_p,
        }
    }
}

#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(VertexChat),
}

#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
    #[serde(rename = "instances")]
    pub instances: Vec<VertexInstance>,
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
    pub predictions: Vec<String>,
}

/// Generate tokens from Vertex request
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
        (status = 200, description = "Generated Text", body = VertexResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
pub(crate) async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    metrics::counter!("tgi_request_count").increment(1);

    // Check that there's at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Prepare futures for all instances
    let mut futures = Vec::with_capacity(req.instances.len());

    for instance in req.instances.into_iter() {
        let generate_request = match instance {
            VertexInstance::Generate(instance) => GenerateRequest {
                inputs: instance.inputs.clone(),
                add_special_tokens: true,
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            },
            VertexInstance::Chat(instance) => {
                let chat_request: ChatRequest = instance.into();
                let (generate_request, _using_tools): (GenerateRequest, bool) =
                    chat_request.try_into_generate(&infer)?;
                generate_request
            }
        };

        let infer_clone = infer.clone();
        let compute_type_clone = compute_type.clone();
        let span_clone = span.clone();

        futures.push(async move {
            generate_internal(
                Extension(infer_clone),
                compute_type_clone,
                Json(generate_request),
                span_clone,
            )
            .await
            .map(|(_, Json(generation))| generation.generated_text)
            .map_err(|_| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: "Incomplete generation".into(),
                        error_type: "Incomplete generation".into(),
                    }),
                )
            })
        });
    }

    // execute all futures in parallel, collect results, returning early if any error occurs
    let results = futures::future::join_all(futures).await;
    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
    let predictions = predictions?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Message, MessageContent};

    #[test]
    fn vertex_deserialization() {
        let string = serde_json::json!({
            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
            "parameters": {
                "max_tokens": 128,
                "top_p": 0.95,
                "temperature": 0.7
            }
        });

        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");

        let string = serde_json::json!({
            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
        });

        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");

        let string = serde_json::json!({
            "instances": [
                {
                    "messages": [{"role": "user", "content": "What's Deep Learning?"}],
                    "parameters": {
                        "max_tokens": 128,
                        "top_p": 0.95,
                        "temperature": 0.7
                    }
                }
            ]
        });
        let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
        assert_eq!(
            request,
            VertexRequest {
                instances: vec![VertexInstance::Chat(VertexChat {
                    messages: vec![Message {
                        role: "user".to_string(),
                        content: MessageContent::SingleText("What's Deep Learning?".to_string()),
                        name: None,
                    },],
                    parameters: VertexParameters {
                        max_tokens: Some(128),
                        top_p: Some(0.95),
                        temperature: Some(0.7),
                        ..Default::default()
                    }
                })]
            }
        );
    }
}

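For orientation, here is a minimal sketch of how a client could exercise the /vertex route defined above. The host, port, and use of the `requests` library are assumptions for illustration, not part of this change; the payload shape mirrors `VertexRequest`/`VertexChat` and the deserialization test.

# Minimal client sketch for the /vertex route above (assumed host/port;
# the `requests` dependency is an assumption, not part of this change).
import requests

payload = {
    "instances": [
        {
            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
            "parameters": {"max_tokens": 128, "top_p": 0.95, "temperature": 0.7},
        }
    ]
}

resp = requests.post("http://localhost:3000/vertex", json=payload, timeout=60)
resp.raise_for_status()
# The handler answers with a VertexResponse: {"predictions": ["<generated text>", ...]}
print(resp.json()["predictions"])
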
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
+flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4

 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
@@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
 build-flash-attention-v2-rocm:
 	if [ ! -d 'flash-attention-v2' ]; then \
 		pip install -U packaging ninja --no-cache-dir && \
-		git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
+		git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
 		cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
 		git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
 	fi

@@ -1,5 +1,5 @@
 commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
-		git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
+		git clone https://github.com/mht-sharma/vllm.git vllm; \
 	fi
 	cd vllm && git fetch && git checkout $(commit_rocm) && \
 		PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

@@ -1,5 +1,17 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+
+extra_cuda_cflags = []
+extra_cflags = []
+if torch.version.hip:
+    extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
+    extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
+
+extra_compile_args = {
+    "cxx": extra_cflags,
+    "nvcc": extra_cuda_cflags,
+}
+
 setup(
     name="exllama_kernels",
@@ -13,6 +25,7 @@ setup(
                 "exllama_kernels/cuda_func/q4_matmul.cu",
                 "exllama_kernels/cuda_func/q4_matrix.cu",
             ],
+            extra_compile_args=extra_compile_args,
         )
     ],
     cmdclass={"build_ext": BuildExtension},

@@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 import torch

 extra_cuda_cflags = ["-lineinfo", "-O3"]
+extra_cflags = []
 if torch.version.hip:
-    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+    extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
+    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]

 extra_compile_args = {
+    "cxx": extra_cflags,
     "nvcc": extra_cuda_cflags,
 }

server/poetry.lock (generated, 2652 lines changed): file diff suppressed because it is too large.

@@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
 opentelemetry-exporter-otlp = "^1.25.0"
 opentelemetry-instrumentation-grpc = "^0.46b0"
 hf-transfer = "^0.1.2"
-sentencepiece = "^0.1.97"
+sentencepiece = "^0.2"
-tokenizers = "^0.19.1"
+tokenizers = "^0.20"
 huggingface-hub = "^0.23"
-transformers = "^4.43"
+transformers = "^4.45"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
@@ -46,6 +46,12 @@ marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
+moe-kernels = [
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+]
 rich = "^13.7.1"

 [tool.poetry.extras]
@@ -53,6 +59,7 @@ torch = ["torch"]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
 marlin = ["marlin-kernels"]
+moe = ["moe-kernels"]
 peft = ["peft"]
 quantize = ["texttable", "datasets", "accelerate"]
 outlines = ["outlines"]

@@ -1,19 +1,19 @@
-certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
 packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
-pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
-rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
-sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
+sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
-zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

The same set of pinned-version bumps shown above is repeated verbatim in two further requirements files changed by this commit; their diffs are identical to the one above.

@@ -30,6 +30,10 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"


+class KVCacheDtype(str, Enum):
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
@@ -38,6 +42,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: Optional[KVCacheDtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -97,6 +102,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
+    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
     if dtype is not None and quantize not in {
         None,
         "bitsandbytes",
@@ -114,6 +120,7 @@ def serve(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         uds_path,
         max_input_tokens,

@@ -1,29 +1,47 @@
-from text_generation_server.utils.import_utils import SYSTEM
 import os
+
+from text_generation_server.utils.import_utils import SYSTEM
+
 from .common import Seqlen
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
     from .cuda import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
     )
 elif SYSTEM == "rocm":
-    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+    from .rocm import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
+        attention,
+        paged_attention,
+        reshape_and_cache,
+    )
 elif SYSTEM == "ipex":
-    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+    from .ipex import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
+        attention,
+        paged_attention,
+        reshape_and_cache,
+    )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
+
+# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
+from .kv_cache import KVCache
+
 __all__ = [
     "attention",
     "paged_attention",
     "reshape_and_cache",
+    "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
+    "KVCache",
     "Seqlen",
 ]

@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import ATTENTION
 import torch
 from typing import Optional
@@ -65,5 +66,7 @@ else:
         max_k: int

     def clamp(self, max):
+        if SYSTEM == "rocm":
+            return self
         raise NotImplementedError("Not implemented seqlen for paged")
         return Seqlen(torch.clamp(self.input_lengths, max=max))

@@ -287,16 +287,14 @@ elif V2:
 else:

     def attention(
-        q,
-        k,
-        v,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-        causal=None,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
+        softmax_scale: float,
+        window_size_left: int = -1,
+        causal: bool = True,
         softcap=None,
     ):
         if window_size_left != -1:
@@ -338,16 +336,30 @@ else:
             k,
             v,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             0,
             None,
         )
         return out
+
+
+# Prefill in the cache with every kind of attention, unless we
+# have a configuration that requires flash-attention v1, which
+# does not support block tables.
+PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+
+__all__ = [
+    "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+]

|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
query_dtype: str = "float16",
|
dtype: torch.dtype,
|
||||||
|
window_left: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Context manager to set the active flashinfer prefill state to the given
|
Context manager to set the active flashinfer prefill state to the given
|
||||||
@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
|
|||||||
num_qo_heads=num_heads,
|
num_qo_heads=num_heads,
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_dim=head_size,
|
head_dim=head_size,
|
||||||
q_data_type=query_dtype,
|
q_data_type=dtype,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
|
window_left=window_left,
|
||||||
)
|
)
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
@ -119,7 +121,8 @@ def use_prefill_state(
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
query_dtype: str = "float16",
|
dtype: torch.dtype,
|
||||||
|
window_left: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Context manager to set the active flashinfer prefill state to the given
|
Context manager to set the active flashinfer prefill state to the given
|
||||||
@ -135,7 +138,8 @@ def use_prefill_state(
|
|||||||
num_qo_heads=num_heads,
|
num_qo_heads=num_heads,
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_dim=head_size,
|
head_dim=head_size,
|
||||||
q_data_type=query_dtype,
|
q_data_type=dtype,
|
||||||
|
window_left=window_left,
|
||||||
)
|
)
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
@ -152,11 +156,13 @@ def create_decode_state(
|
|||||||
):
|
):
|
||||||
"""Create a decode state."""
|
"""Create a decode state."""
|
||||||
workspace_buffer = get_workspace(device)
|
workspace_buffer = get_workspace(device)
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer,
|
workspace_buffer,
|
||||||
kv_layout="NHD",
|
kv_layout="NHD",
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||||
|
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -175,6 +181,7 @@ def create_decode_state_cuda_graphs(
|
|||||||
therefore stored as part of the state.
|
therefore stored as part of the state.
|
||||||
"""
|
"""
|
||||||
workspace_buffer = get_workspace(device)
|
workspace_buffer = get_workspace(device)
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer,
|
workspace_buffer,
|
||||||
kv_layout="NHD",
|
kv_layout="NHD",
|
||||||
@ -182,7 +189,8 @@ def create_decode_state_cuda_graphs(
|
|||||||
paged_kv_indices_buffer=block_tables,
|
paged_kv_indices_buffer=block_tables,
|
||||||
paged_kv_indptr_buffer=block_tables_ptr,
|
paged_kv_indptr_buffer=block_tables_ptr,
|
||||||
paged_kv_last_page_len_buffer=last_page_len,
|
paged_kv_last_page_len_buffer=last_page_len,
|
||||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||||
|
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -196,7 +204,8 @@ def use_decode_state(
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
query_dtype: str = "float16",
|
dtype: torch.dtype,
|
||||||
|
window_left: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Context manager to set the active flashinfer decoding state to the given
|
Context manager to set the active flashinfer decoding state to the given
|
||||||
@ -231,7 +240,9 @@ def use_decode_state(
|
|||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_dim=head_size,
|
head_dim=head_size,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
q_data_type=query_dtype,
|
data_type=dtype,
|
||||||
|
q_data_type=dtype,
|
||||||
|
window_left=window_left,
|
||||||
)
|
)
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
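As a side note on the `use_tensor_cores` change above: the decision now keys off the grouped-query ratio (query heads per KV head) rather than the previous `> 4` threshold. A small standalone sketch of the heuristic; the head counts below are illustrative assumptions, not values taken from this diff.

# Standalone sketch of the use_tensor_cores heuristic adopted above.
def use_tensor_cores(num_heads: int, num_kv_heads: int) -> bool:
    # Grouped-query attention ratio: query heads served by each KV head.
    num_groups = num_heads // num_kv_heads
    # Per the flashinfer comment linked in the diff, group sizes outside
    # {1, 2, 4, 8} should take the tensor-core decode path.
    return num_groups not in (1, 2, 4, 8)

print(use_tensor_cores(32, 8))   # group size 4 -> False
print(use_tensor_cores(40, 8))   # group size 5 -> True
print(use_tensor_cores(96, 12))  # group size 8 -> False (old heuristic: 8 > 4 -> True)
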
Some files were not shown because too many files have changed in this diff.