mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Merge branch 'main' into mi300-compat
This commit is contained in:
commit
8ec3b1a7a7
93
.github/workflows/build.yaml
vendored
93
.github/workflows/build.yaml
vendored
@ -274,12 +274,105 @@ jobs:
|
||||
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
|
||||
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
|
||||
|
||||
build-and-push-image-intel:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-build-and-push-image-intel-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
needs:
|
||||
- start-runner
|
||||
- build-and-push-image # Wait for the main docker image to be built
|
||||
- integration-tests # Wait for the main integration-tests
|
||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
# This is used to complete the identity challenge
|
||||
# with sigstore/fulcio when running outside of PRs.
|
||||
id-token: write
|
||||
security-events: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Initialize Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2.0.0
|
||||
with:
|
||||
install: true
|
||||
- name: Inject slug/short variables
|
||||
uses: rlespinasse/github-slug-action@v4.4.1
|
||||
- name: Tailscale
|
||||
uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
|
||||
with:
|
||||
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
|
||||
- name: Login to GitHub Container Registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Login to internal Container Registry
|
||||
uses: docker/login-action@v2.1.0
|
||||
with:
|
||||
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
|
||||
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
|
||||
registry: registry.internal.huggingface.tech
|
||||
- name: Login to Azure Container Registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v2.1.0
|
||||
with:
|
||||
username: ${{ secrets.AZURE_DOCKER_USERNAME }}
|
||||
password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
|
||||
registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
|
||||
# If pull request
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
id: meta-pr
|
||||
uses: docker/metadata-action@v4.3.0
|
||||
with:
|
||||
images: |
|
||||
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
|
||||
tags: |
|
||||
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
|
||||
# If main, release or tag
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4.3.0
|
||||
with:
|
||||
flavor: |
|
||||
latest=false
|
||||
images: |
|
||||
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
|
||||
ghcr.io/huggingface/text-generation-inference
|
||||
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
|
||||
tags: |
|
||||
type=semver,pattern={{version}}-intel
|
||||
type=semver,pattern={{major}}.{{minor}}-intel
|
||||
type=raw,value=latest-intel,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
|
||||
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
|
||||
- name: Build and push Docker image
|
||||
id: build-and-push
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile_intel
|
||||
push: true
|
||||
platforms: 'linux/amd64'
|
||||
build-args: |
|
||||
GIT_SHA=${{ env.GITHUB_SHA }}
|
||||
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-intel
|
||||
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
|
||||
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
|
||||
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
|
||||
|
||||
stop-runner:
|
||||
name: Stop self-hosted EC2 runner
|
||||
needs:
|
||||
- start-runner
|
||||
- build-and-push-image
|
||||
- build-and-push-image-rocm
|
||||
- build-and-push-image-intel
|
||||
- integration-tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -11,3 +11,6 @@ server/exllama_kernels/exllama_kernels/hip_func/
|
||||
*_hip.cuh
|
||||
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
|
||||
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
|
||||
|
||||
data/
|
||||
load_tests/*.json
|
||||
|
@ -9,7 +9,7 @@ members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "2.0.1"
|
||||
version = "2.0.2"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
|
@ -39,7 +39,7 @@ RUN cargo build --release
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
|
||||
|
||||
ARG PYTORCH_VERSION=2.1.1
|
||||
ARG PYTORCH_VERSION=2.3.0
|
||||
ARG PYTHON_VERSION=3.10
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG CUDA_VERSION=12.1
|
||||
@ -149,6 +149,8 @@ FROM kernel-builder as vllm-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
|
||||
|
||||
COPY server/Makefile-vllm Makefile
|
||||
|
||||
# Build specific version of vllm
|
||||
@ -210,7 +212,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Install vllm/flash-attention dependencies
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
||||
# Install server
|
||||
@ -246,6 +248,7 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||
FROM base
|
||||
|
||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
CMD ["--json-output"]
|
||||
|
105
Dockerfile_intel
Normal file
105
Dockerfile_intel
Normal file
@ -0,0 +1,105 @@
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY launcher launcher
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
|
||||
ARG GIT_SHA
|
||||
ARG DOCKER_LABEL
|
||||
|
||||
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
|
||||
rm -f $PROTOC_ZIP
|
||||
|
||||
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
|
||||
|
||||
# Text Generation Inference base image for Intel
|
||||
FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base
|
||||
|
||||
USER root
|
||||
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
|
||||
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
|
||||
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
|
||||
|
||||
|
||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
|
||||
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
|
||||
|
||||
RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
|
||||
WORKDIR /usr/src
|
||||
# Build pytorch and ipex
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b xpu_main origin/xpu-main
|
||||
RUN git clone https://github.com/pytorch/pytorch.git && cd pytorch && git checkout 209f2fa8ff86652f67d75c2f19bf9cb9942fd018 && git apply /usr/src/intel-extension-for-pytorch/torch_patches/00*.patch
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_cuda.txt && \
|
||||
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
|
||||
ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
|
||||
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
|
||||
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
|
||||
ENV DIAGUTIL_PATH=/opt/intel/oneapi/compiler/latest/etc/compiler/sys_check/sys_check.sh
|
||||
ENV CCL_CONFIGURATION=cpu_gpu_dpcpp
|
||||
ENV MANPATH=/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/compiler/latest/share/man
|
||||
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
|
||||
ENV CMPLR_ROOT=/opt/intel/oneapi/compiler/latest
|
||||
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
|
||||
ENV OCL_ICD_FILENAMES=libintelocl_emu.so:libalteracl.so:/opt/intel/oneapi/compiler/latest/lib/libintelocl.so
|
||||
ENV CLASSPATH=/opt/intel/oneapi/mpi/latest/share/java/mpi.jar:/opt/intel/oneapi/mpi/latest/share/java/mpi.jar
|
||||
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
|
||||
ENV MKLROOT=/opt/intel/oneapi/mkl/latest
|
||||
ENV NLSPATH=/opt/intel/oneapi/mkl/latest/share/locale/%l_%t/%N:/opt/intel/oneapi/compiler/latest/lib/locale/%l_%t/%N
|
||||
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
|
||||
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
||||
|
||||
|
||||
RUN pip uninstall -y torch && cd pytorch && git submodule update --init --recursive && python setup.py install
|
||||
RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=ON BUILD_WITH_CPU=ON USE_XETLA=ON python setup.py install
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
# Final image
|
||||
FROM base
|
||||
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
@ -64,6 +64,7 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
|
||||
- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)
|
||||
- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)
|
||||
- [Gaudi](https://github.com/huggingface/tgi-gaudi)
|
||||
- [Google TPU](https://huggingface.co/docs/optimum-tpu/howto/serving)
|
||||
|
||||
|
||||
## Get Started
|
||||
|
@ -80,6 +80,7 @@ class Client:
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_prompt: Optional[str] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -119,6 +120,8 @@ class Client:
|
||||
higher are kept for generation
|
||||
tools (`List[Tool]`):
|
||||
List of tools to use
|
||||
tool_prompt (`str`):
|
||||
A prompt to be appended before the tools
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
|
||||
@ -139,6 +142,7 @@ class Client:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tools=tools,
|
||||
tool_prompt=tool_prompt,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
if not stream:
|
||||
@ -466,6 +470,7 @@ class AsyncClient:
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_prompt: Optional[str] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
|
||||
"""
|
||||
@ -505,6 +510,8 @@ class AsyncClient:
|
||||
higher are kept for generation
|
||||
tools (`List[Tool]`):
|
||||
List of tools to use
|
||||
tool_prompt (`str`):
|
||||
A prompt to be appended before the tools
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
|
||||
@ -525,6 +532,7 @@ class AsyncClient:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tools=tools,
|
||||
tool_prompt=tool_prompt,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
if not stream:
|
||||
|
@ -159,6 +159,8 @@ class ChatRequest(BaseModel):
|
||||
top_p: Optional[float] = None
|
||||
# List of tools to be used
|
||||
tools: Optional[List[Tool]] = None
|
||||
# A prompt to be appended before the tools
|
||||
tool_prompt: Optional[str] = None
|
||||
# Choice of tool to be used
|
||||
tool_choice: Optional[str] = None
|
||||
|
||||
|
@ -25,6 +25,10 @@
|
||||
title: Non-core Model Serving
|
||||
- local: basic_tutorials/safety
|
||||
title: Safety
|
||||
- local: basic_tutorials/using_guidance
|
||||
title: Using Guidance, JSON, tools
|
||||
- local: basic_tutorials/visual_language_models
|
||||
title: Visual Language Models
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- local: conceptual/streaming
|
||||
@ -42,5 +46,6 @@
|
||||
- local: conceptual/speculation
|
||||
title: Speculation (Medusa, ngram)
|
||||
- local: conceptual/guidance
|
||||
title: Guidance, JSON, tools (using outlines)
|
||||
title: How Guidance Works (via outlines)
|
||||
|
||||
title: Conceptual Guides
|
||||
|
@ -162,7 +162,7 @@ Options:
|
||||
This setting is only applied if there is room in the batch as defined by `max_batch_total_tokens`.
|
||||
|
||||
[env: WAITING_SERVED_RATIO=]
|
||||
[default: 1.2]
|
||||
[default: 0.3]
|
||||
|
||||
```
|
||||
## MAX_BATCH_PREFILL_TOKENS
|
||||
|
419
docs/source/basic_tutorials/using_guidance.md
Normal file
419
docs/source/basic_tutorials/using_guidance.md
Normal file
@ -0,0 +1,419 @@
|
||||
# Guidance
|
||||
|
||||
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
|
||||
|
||||
These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
|
||||
|
||||
_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `/chat/completions` endpoint._
|
||||
|
||||
## How it works
|
||||
|
||||
TGI leverages the [outlines](https://github.com/outlines-dev/outlines) library to efficiently parse and compile the grammatical structures and tools specified by users. This integration transforms the defined grammars into an intermediate representation that acts as a framework to guide and constrain content generation, ensuring that outputs adhere to the specified grammatical rules.
|
||||
|
||||
If you are interested in the technical details on how outlines is used in TGI, you can check out the [conceptual guidance documentation](../conceptual/guidance).
|
||||
|
||||
## Table of Contents 📚
|
||||
|
||||
### Grammar and Constraints
|
||||
|
||||
- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
|
||||
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
|
||||
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
|
||||
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
|
||||
|
||||
### Tools and Functions
|
||||
|
||||
- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
|
||||
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
|
||||
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
|
||||
|
||||
## Grammar and Constraints 🛣️
|
||||
|
||||
### The Grammar Parameter
|
||||
|
||||
In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the LLM.
|
||||
|
||||
Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.
|
||||
|
||||
```json
|
||||
curl localhost:3000/generate \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
|
||||
"parameters": {
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"activity": {
|
||||
"type": "string"
|
||||
},
|
||||
"animals_seen": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 5
|
||||
},
|
||||
"animals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["location", "activity", "animals_seen", "animals"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}
|
||||
|
||||
```
|
||||
|
||||
A grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The LLM will then generate a response that conforms to the specified grammar.
|
||||
|
||||
> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
|
||||
|
||||
### Constrain with Pydantic
|
||||
|
||||
Using Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way.
|
||||
|
||||
```python
|
||||
import requests
|
||||
from pydantic import BaseModel, conint
|
||||
from typing import List
|
||||
|
||||
class Animals(BaseModel):
|
||||
location: str
|
||||
activity: str
|
||||
animals_seen: conint(ge=1, le=5) # Constrained integer type
|
||||
animals: List[str]
|
||||
|
||||
prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"
|
||||
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": Animals.schema()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://127.0.0.1:3000/generate',
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
print(response.json())
|
||||
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
|
||||
|
||||
```
|
||||
|
||||
### JSON Schema Integration
|
||||
|
||||
If Pydantic's not your style, go raw with direct JSON Schema integration. This is similar to the first example but with programmatic control.
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
json_schema = {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"activity": {
|
||||
"type": "string"
|
||||
},
|
||||
"animals_seen": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 5
|
||||
},
|
||||
"animals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["location", "activity", "animals_seen", "animals"]
|
||||
}
|
||||
|
||||
data = {
|
||||
"inputs": "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
|
||||
"parameters": {
|
||||
"max_new_tokens": 200,
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": json_schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://127.0.0.1:3000/generate',
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
print(response.json())
|
||||
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
|
||||
|
||||
```
|
||||
|
||||
### Using the client
|
||||
|
||||
TGI provides a client library to that make it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.
|
||||
|
||||
```python
|
||||
from text_generation import AsyncClient
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
# Define an async function to encapsulate the async operation
|
||||
async def main():
|
||||
client = AsyncClient(base_url="http://localhost:3000")
|
||||
|
||||
# Use 'await' to wait for the async method 'chat' to complete
|
||||
response = await client.generate(
|
||||
"Whats Googles DNS",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
seed=1,
|
||||
grammar={
|
||||
"type": GrammarType.Regex,
|
||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||
},
|
||||
)
|
||||
|
||||
# Once the response is received, you can process it
|
||||
print(response.generated_text)
|
||||
|
||||
# Ensure the main async function is run in the event loop
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
# 118.8.0.84
|
||||
|
||||
```
|
||||
|
||||
## Tools and Functions 🛠️
|
||||
|
||||
### The Tools Parameter
|
||||
|
||||
In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.
|
||||
|
||||
Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
|
||||
|
||||
Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
|
||||
|
||||
```json
|
||||
curl localhost:3000/v1/chat/completions \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in New York?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location."
|
||||
}
|
||||
},
|
||||
"required": ["location", "format"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": "get_current_weather"
|
||||
}'
|
||||
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
|
||||
```
|
||||
|
||||
### Text Generation Inference Client
|
||||
|
||||
TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.
|
||||
|
||||
```python
|
||||
from text_generation import AsyncClient
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
# Define an async function to encapsulate the async operation
|
||||
async def main():
|
||||
client = AsyncClient(base_url="http://localhost:3000")
|
||||
|
||||
# Use 'await' to wait for the async method 'chat' to complete
|
||||
response = await client.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You're a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Once the response is received, you can process it
|
||||
print(response.choices[0].message.tool_calls)
|
||||
|
||||
# Ensure the main async function is run in the event loop
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}
|
||||
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Tools used in example above</summary>
|
||||
|
||||
```python
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_n_day_weather_forecast",
|
||||
"description": "Get an N-day weather forecast",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
"num_days": {
|
||||
"type": "integer",
|
||||
"description": "The number of days to forecast",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format", "num_days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### OpenAI integration
|
||||
|
||||
TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
|
||||
|
||||
However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize the client, pointing it to one of the available models
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:3000/v1",
|
||||
api_key="_",
|
||||
)
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like the next 3 days in San Francisco, CA?",
|
||||
},
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice="auto", # tool selected by model
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
|
||||
called = chat_completion.choices[0].message.tool_calls
|
||||
print(called)
|
||||
# {
|
||||
# "id": 0,
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "description": None,
|
||||
# "name": "tools",
|
||||
# "parameters": {
|
||||
# "format": "celsius",
|
||||
# "location": "San Francisco, CA",
|
||||
# "num_days": 3,
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
```
|
170
docs/source/basic_tutorials/visual_language_models.md
Normal file
170
docs/source/basic_tutorials/visual_language_models.md
Normal file
@ -0,0 +1,170 @@
|
||||
# Vision Language Model Inference in TGI
|
||||
|
||||
Visual Language Model (VLM) are models that consume both image and text inputs to generate text.
|
||||
|
||||
VLM's are trained on a combination of image and text data and can handle a wide range of tasks, such as image captioning, visual question answering, and visual dialog.
|
||||
|
||||
> What distinguishes VLMs from other text and image models is their ability to handle long context and generate text that is coherent and relevant to the image even after multiple turns or in some cases, multiple images.
|
||||
|
||||
Below are couple of common use cases for vision language models:
|
||||
|
||||
- **Image Captioning**: Given an image, generate a caption that describes the image.
|
||||
- **Visual Question Answering (VQA)**: Given an image and a question about the image, generate an answer to the question.
|
||||
- **Mulimodal Dialog**: Generate response to multiple turns of images and conversations.
|
||||
- **Image Information Retrieval**: Given an image, retrieve information from the image.
|
||||
|
||||
## How to Use a Vision Language Model?
|
||||
|
||||
### Hugging Face Hub Python Library
|
||||
|
||||
To infer with vision language models through Python, you can use the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The `InferenceClient` class provides a simple way to interact with the [Inference API](https://huggingface.co/docs/api-inference/index). Images can be passed as URLs or base64-encoded strings. The `InferenceClient` will automatically detect the image format.
|
||||
|
||||
```python
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
client = InferenceClient("http://127.0.0.1:3000")
|
||||
image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
|
||||
prompt = f"What is this a picture of?\n\n"
|
||||
for token in client.text_generation(prompt, max_new_tokens=16, stream=True):
|
||||
print(token)
|
||||
|
||||
# This is a picture of an anthropomorphic rabbit in a space suit.
|
||||
```
|
||||
|
||||
```python
|
||||
from huggingface_hub import InferenceClient
|
||||
import base64
|
||||
import requests
|
||||
import io
|
||||
|
||||
client = InferenceClient("http://127.0.0.1:3000")
|
||||
|
||||
# read image from local file
|
||||
image_path = "rabbit.png"
|
||||
with open(image_path, "rb") as f:
|
||||
image = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
image = f"data:image/png;base64,{image}"
|
||||
prompt = f"What is this a picture of?\n\n"
|
||||
|
||||
for token in client.text_generation(prompt, max_new_tokens=10, stream=True):
|
||||
print(token)
|
||||
|
||||
# This is a picture of an anthropomorphic rabbit in a space suit.
|
||||
```
|
||||
|
||||
If you want additional details, you can add `details=True`. In this case, you get a `TextGenerationStreamResponse` which contains additional information such as the probabilities and the tokens. For the final response in the stream, it also returns the full generated text.
|
||||
|
||||
### Inference Through Sending `cURL` Requests
|
||||
|
||||
To use the `generate_stream` endpoint with curl, you can add the `-N` flag. This flag disables curl default buffering and shows data as it arrives from the server.
|
||||
|
||||
```bash
|
||||
curl -N 127.0.0.1:3000/generate_stream \
|
||||
-X POST \
|
||||
-d '{"inputs":"What is this a picture of?\n\n","parameters":{"max_new_tokens":16, "seed": 42}}' \
|
||||
-H 'Content-Type: application/json'
|
||||
|
||||
# ...
|
||||
# data:{"index":16,"token":{"id":28723,"text":".","logprob":-0.6196289,"special":false},"generated_text":"This is a picture of an anthropomorphic rabbit in a space suit.","details":null}
|
||||
```
|
||||
|
||||
### Inference Through JavaScript
|
||||
|
||||
First, we need to install the `@huggingface/inference` library.
|
||||
|
||||
```bash
|
||||
npm install @huggingface/inference
|
||||
```
|
||||
|
||||
If you're using the free Inference API, you can use [Huggingface.js](https://huggingface.co/docs/huggingface.js/inference/README)'s `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint` class to easily interact with the Inference API.
|
||||
|
||||
We can create a `HfInferenceEndpoint` providing our endpoint URL and We can create a `HfInferenceEndpoint` providing our endpoint URL and [Hugging Face access token](https://huggingface.co/settings/tokens).
|
||||
|
||||
```js
|
||||
import { HfInferenceEndpoint } from "@huggingface/inference";
|
||||
|
||||
const hf = new HfInferenceEndpoint("http://127.0.0.1:3000", "HF_TOKEN");
|
||||
|
||||
const prompt =
|
||||
"What is this a picture of?\n\n";
|
||||
|
||||
const stream = hf.textGenerationStream({
|
||||
inputs: prompt,
|
||||
parameters: { max_new_tokens: 16, seed: 42 },
|
||||
});
|
||||
for await (const r of stream) {
|
||||
// yield the generated token
|
||||
process.stdout.write(r.token.text);
|
||||
}
|
||||
|
||||
// This is a picture of an anthropomorphic rabbit in a space suit.
|
||||
```
|
||||
|
||||
## Combining Vision Language Models with Other Features
|
||||
|
||||
VLMs in TGI have several advantages, for example these models can be used in tandem with other features for more complex tasks. For example, you can use VLMs with [Guided Generation](/docs/conceptual/guided-generation) to generate specific JSON data from an image.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
|
||||
width="400"
|
||||
/>
|
||||
</div>
|
||||
|
||||
For example we can extract information from the rabbit image and generate a JSON object with the location, activity, number of animals seen, and the animals seen. That would look like this:
|
||||
|
||||
```json
|
||||
{
|
||||
"activity": "Standing",
|
||||
"animals": ["Rabbit"],
|
||||
"animals_seen": 1,
|
||||
"location": "Rocky surface with mountains in the background and a red light on the rabbit's chest"
|
||||
}
|
||||
```
|
||||
|
||||
All we need to do is provide a JSON schema to the VLM model and it will generate the JSON object for us.
|
||||
|
||||
```bash
|
||||
curl localhost:3000/generate \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"inputs":"What is this a picture of?\n\n",
|
||||
"parameters": {
|
||||
"max_new_tokens": 100,
|
||||
"seed": 42,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"activity": {
|
||||
"type": "string"
|
||||
},
|
||||
"animals_seen": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 5
|
||||
},
|
||||
"animals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["location", "activity", "animals_seen", "animals"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
|
||||
# {
|
||||
# "generated_text": "{ \"activity\": \"Standing\", \"animals\": [ \"Rabbit\" ], \"animals_seen\": 1, \"location\": \"Rocky surface with mountains in the background and a red light on the rabbit's chest\" }"
|
||||
# }
|
||||
```
|
||||
|
||||
Want to learn more about how Vision Language Models work? Check out the [awesome blog post on the topic](https://huggingface.co/blog/vlms).
|
@ -1,421 +1,86 @@
|
||||
# Guidance
|
||||
|
||||
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
|
||||
## What is Guidance?
|
||||
|
||||
These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
|
||||
Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format.
|
||||
|
||||
> The Grammar guidance support is currently only available in the TGI API due to lack of support in Open AI API.
|
||||
## How is it used?
|
||||
|
||||
## Quick Start
|
||||
Guidance can be in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance:
|
||||
|
||||
Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide.
|
||||
Technically, guidance can be used to generate:
|
||||
|
||||
If you're not up to date, grab the latest version and let's get started!
|
||||
- a specific JSON object
|
||||
- a function signature
|
||||
- typed output like a list of integers
|
||||
|
||||
## Table of Contents 📚
|
||||
However these use cases can span a wide range of applications, such as:
|
||||
|
||||
### Grammar and Constraints
|
||||
- extracting structured data from unstructured text
|
||||
- summarizing text into a specific format
|
||||
- limit output to specific classes of words (act as a LLM powered classifier)
|
||||
- generate the input to specific APIs or services
|
||||
- provide reliable and consistent output for downstream tasks
|
||||
- extract data from multimodal inputs
|
||||
|
||||
- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
|
||||
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
|
||||
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
|
||||
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
|
||||
## How it works?
|
||||
|
||||
### Tools and Functions
|
||||
Diving into the details, guidance is enabled by including a grammar with a generation request that is compiled, and used to modify the chosen tokens.
|
||||
|
||||
- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
|
||||
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
|
||||
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
|
||||
This process can be broken down into the following steps:
|
||||
|
||||
## Grammar and Constraints 🛣️
|
||||
1. A request is sent to the backend, it is processed and placed in batch. Processing includes compiling the grammar into a finite state machine and a grammar state.
|
||||
|
||||
### The Grammar Parameter
|
||||
<div class="flex justify-center">
|
||||
<img
|
||||
class="block dark:hidden"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/request-to-batch.gif"
|
||||
/>
|
||||
<img
|
||||
class="hidden dark:block"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/request-to-batch-dark.gif"
|
||||
/>
|
||||
</div>
|
||||
|
||||
In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output.
|
||||
2. The model does a forward pass over the batch. This returns probabilities for each token in the vocabulary for each request in the batch.
|
||||
|
||||
Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.
|
||||
3. The process of choosing one of those tokens is called `sampling`. The model samples from the distribution of probabilities to choose the next token. In TGI all of the steps before sampling are called `processor`. Grammars are applied as a processor that masks out tokens that are not allowed by the grammar.
|
||||
|
||||
```json
|
||||
curl localhost:3000/generate \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
|
||||
"parameters": {
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"activity": {
|
||||
"type": "string"
|
||||
},
|
||||
"animals_seen": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 5
|
||||
},
|
||||
"animals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["location", "activity", "animals_seen", "animals"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}
|
||||
<div class="flex justify-center">
|
||||
<img
|
||||
class="block dark:hidden"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/logit-grammar-mask.gif"
|
||||
/>
|
||||
<img
|
||||
class="hidden dark:block"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/logit-grammar-mask-dark.gif"
|
||||
/>
|
||||
</div>
|
||||
|
||||
```
|
||||
4. The grammar mask is applied and the model samples from the remaining tokens. Once a token is chosen, we update the grammar state with the new token, to prepare it for the next pass.
|
||||
|
||||
A grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
|
||||
<div class="flex justify-center">
|
||||
<img
|
||||
class="block dark:hidden"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/sample-logits.gif"
|
||||
/>
|
||||
<img
|
||||
class="hidden dark:block"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/sample-logits-dark.gif"
|
||||
/>
|
||||
</div>
|
||||
|
||||
> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
|
||||
## How to use Guidance?
|
||||
|
||||
### Constrain with Pydantic
|
||||
There are two main ways to use guidance; you can either use the `/generate` endpoint with a grammar or use the `/chat/completion` endpoint with tools.
|
||||
|
||||
Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting the a specific response format.
|
||||
Under the hood tools are a special case of grammars that allows the model to choose one or none of the provided tools.
|
||||
|
||||
Using Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way.
|
||||
Please refer to [using guidance](../basic_tutorial/using_guidance) for more examples and details on how to use guidance in Python, JavaScript, and cURL.
|
||||
|
||||
```python
|
||||
import requests
|
||||
from pydantic import BaseModel, conint
|
||||
from typing import List
|
||||
### Getting the most out of guidance
|
||||
|
||||
class Animals(BaseModel):
|
||||
location: str
|
||||
activity: str
|
||||
animals_seen: conint(ge=1, le=5) # Constrained integer type
|
||||
animals: List[str]
|
||||
Depending on how you are using guidance, you may want to make use of different features. Here are some tips to get the most out of guidance:
|
||||
|
||||
prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"
|
||||
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": Animals.schema()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://127.0.0.1:3000/generate',
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
print(response.json())
|
||||
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
|
||||
|
||||
```
|
||||
|
||||
### JSON Schema Integration
|
||||
|
||||
If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is simliar to the first example but with programmatic control.
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
json_schema = {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"activity": {
|
||||
"type": "string"
|
||||
},
|
||||
"animals_seen": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 5
|
||||
},
|
||||
"animals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["location", "activity", "animals_seen", "animals"]
|
||||
}
|
||||
|
||||
data = {
|
||||
"inputs": "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
|
||||
"parameters": {
|
||||
"max_new_tokens": 200,
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": json_schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://127.0.0.1:3000/generate',
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
print(response.json())
|
||||
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
|
||||
|
||||
```
|
||||
|
||||
### Using the client
|
||||
|
||||
TGI provides a client library to that make it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.
|
||||
|
||||
```python
|
||||
from text_generation import AsyncClient
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
# Define an async function to encapsulate the async operation
|
||||
async def main():
|
||||
client = AsyncClient(base_url="http://localhost:3000")
|
||||
|
||||
# Use 'await' to wait for the async method 'chat' to complete
|
||||
response = await client.generate(
|
||||
"Whats Googles DNS",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
seed=1,
|
||||
grammar={
|
||||
"type": GrammarType.Regex,
|
||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||
},
|
||||
)
|
||||
|
||||
# Once the response is received, you can process it
|
||||
print(response.generated_text)
|
||||
|
||||
# Ensure the main async function is run in the event loop
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
# 118.8.0.84
|
||||
|
||||
```
|
||||
|
||||
## Tools and Functions 🛠️
|
||||
|
||||
### The Tools Parameter
|
||||
|
||||
In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.
|
||||
|
||||
Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.
|
||||
|
||||
Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
|
||||
|
||||
```json
|
||||
curl localhost:3000/v1/chat/completions \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in New York?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location."
|
||||
}
|
||||
},
|
||||
"required": ["location", "format"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": "get_current_weather"
|
||||
}'
|
||||
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Tools used in example below</summary>
|
||||
|
||||
```python
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_n_day_weather_forecast",
|
||||
"description": "Get an N-day weather forecast",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
"num_days": {
|
||||
"type": "integer",
|
||||
"description": "The number of days to forecast",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format", "num_days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Text Generation Inference Client
|
||||
|
||||
TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.
|
||||
|
||||
```python
|
||||
from text_generation import AsyncClient
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
# Define an async function to encapsulate the async operation
|
||||
async def main():
|
||||
client = AsyncClient(base_url="http://localhost:3000")
|
||||
|
||||
# Use 'await' to wait for the async method 'chat' to complete
|
||||
response = await client.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You're a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Once the response is received, you can process it
|
||||
print(response.choices[0].message.tool_calls)
|
||||
|
||||
# Ensure the main async function is run in the event loop
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}
|
||||
|
||||
```
|
||||
|
||||
### OpenAI integration
|
||||
|
||||
TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
|
||||
|
||||
However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize the client, pointing it to one of the available models
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:3000/v1",
|
||||
api_key="_",
|
||||
)
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like the next 3 days in San Francisco, CA?",
|
||||
},
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice="auto", # tool selected by model
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
|
||||
called = chat_completion.choices[0].message.tool_calls
|
||||
print(called)
|
||||
# {
|
||||
# "id": 0,
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "description": None,
|
||||
# "name": "tools",
|
||||
# "parameters": {
|
||||
# "format": "celsius",
|
||||
# "location": "San Francisco, CA",
|
||||
# "num_days": 3,
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
```
|
||||
- If you are using the `/generate` with a `grammar` it is recommended to include the grammar in the prompt prefixed by something like `Please use the following JSON schema to generate the output:`. This will help the model understand the context of the grammar and generate the output accordingly.
|
||||
- If you are getting a response with many repeated tokens, please use the `frequency_penalty` or `repetition_penalty` to reduce the number of repeated tokens in the output.
|
||||
|
@ -7,14 +7,17 @@ pub(crate) struct Env {
|
||||
git_sha: &'static str,
|
||||
docker_label: &'static str,
|
||||
nvidia_env: String,
|
||||
xpu_env: String,
|
||||
}
|
||||
|
||||
impl Env {
|
||||
pub fn new() -> Self {
|
||||
let nvidia_env = nvidia_smi();
|
||||
let xpu_env = xpu_smi();
|
||||
|
||||
Self {
|
||||
nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
|
||||
xpu_env: xpu_env.unwrap_or("N/A".to_string()),
|
||||
cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
|
||||
cargo_version: env!("VERGEN_RUSTC_SEMVER"),
|
||||
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
|
||||
@ -31,7 +34,8 @@ impl fmt::Display for Env {
|
||||
writeln!(f, "Cargo version: {}", self.cargo_version)?;
|
||||
writeln!(f, "Commit sha: {}", self.git_sha)?;
|
||||
writeln!(f, "Docker label: {}", self.docker_label)?;
|
||||
write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
|
||||
writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
|
||||
write!(f, "xpu-smi:\n{}", self.xpu_env)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
|
||||
let output = nvidia_smi.replace('\n', "\n ");
|
||||
Some(output.trim().to_string())
|
||||
}
|
||||
|
||||
fn xpu_smi() -> Option<String> {
|
||||
let output = Command::new("xpu-smi").arg("discovery").output().ok()?;
|
||||
let xpu_smi = String::from_utf8(output.stdout).ok()?;
|
||||
let output = xpu_smi.replace('\n', "\n ");
|
||||
Some(output.trim().to_string())
|
||||
}
|
||||
|
@ -251,7 +251,7 @@ struct Args {
|
||||
///
|
||||
/// This setting is only applied if there is room in the batch
|
||||
/// as defined by `max_batch_total_tokens`.
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
#[clap(default_value = "0.3", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
|
||||
/// Limits the number of tokens for the prefill operation.
|
||||
@ -448,6 +448,8 @@ fn shard_manager(
|
||||
cuda_memory_fraction: f32,
|
||||
rope_scaling: Option<RopeScaling>,
|
||||
rope_factor: Option<f32>,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
otlp_endpoint: Option<String>,
|
||||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
@ -512,6 +514,7 @@ fn shard_manager(
|
||||
(Some(scaling), Some(factor)) => Some((scaling, factor)),
|
||||
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
|
||||
};
|
||||
|
||||
// OpenTelemetry
|
||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||
shard_args.push("--otlp-endpoint".to_string());
|
||||
@ -564,6 +567,14 @@ fn shard_manager(
|
||||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||
}
|
||||
|
||||
envs.push((
|
||||
"MAX_TOTAL_TOKENS".into(),
|
||||
max_total_tokens.to_string().into(),
|
||||
));
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
||||
}
|
||||
|
||||
// If huggingface_hub_cache is some, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||
@ -965,6 +976,7 @@ fn spawn_shards(
|
||||
num_shard: usize,
|
||||
args: &Args,
|
||||
cuda_graphs: Vec<usize>,
|
||||
max_total_tokens: usize,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
shutdown_sender: mpsc::Sender<()>,
|
||||
@ -996,6 +1008,7 @@ fn spawn_shards(
|
||||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||
let rope_scaling = args.rope_scaling;
|
||||
let rope_factor = args.rope_factor;
|
||||
let max_batch_size = args.max_batch_size;
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_id,
|
||||
@ -1018,6 +1031,8 @@ fn spawn_shards(
|
||||
cuda_memory_fraction,
|
||||
rope_scaling,
|
||||
rope_factor,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
otlp_endpoint,
|
||||
status_sender,
|
||||
shutdown,
|
||||
@ -1228,7 +1243,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
|
||||
signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();
|
||||
|
||||
tracing::info!("Waiting for {process_name} to gracefully shutdown");
|
||||
|
||||
while terminate_time.elapsed() < timeout {
|
||||
if let Some(status) = process.try_wait()? {
|
||||
tracing::info!("{process_name} terminated");
|
||||
@ -1236,7 +1250,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
|
||||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
|
||||
tracing::info!("Killing {process_name}");
|
||||
|
||||
process.kill()?;
|
||||
@ -1271,7 +1284,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
tracing::info!("{}", env_runtime);
|
||||
}
|
||||
|
||||
tracing::info!("{:?}", args);
|
||||
tracing::info!("{:#?}", args);
|
||||
|
||||
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
||||
let model_id = args.model_id.clone();
|
||||
@ -1304,7 +1317,12 @@ fn main() -> Result<(), LauncherError> {
|
||||
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
|
||||
if max_position_embeddings > max_default {
|
||||
let max = max_position_embeddings;
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
if args.max_input_tokens.is_none()
|
||||
&& args.max_total_tokens.is_none()
|
||||
&& args.max_batch_prefill_tokens.is_none()
|
||||
{
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
}
|
||||
max_default
|
||||
} else {
|
||||
max_position_embeddings
|
||||
@ -1376,8 +1394,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
}
|
||||
|
||||
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
||||
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
|
||||
(Some(cuda_graphs), None) => cuda_graphs.clone(),
|
||||
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
|
||||
#[allow(deprecated)]
|
||||
(
|
||||
None,
|
||||
@ -1472,6 +1489,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
num_shard,
|
||||
&args,
|
||||
cuda_graphs,
|
||||
max_total_tokens,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
shutdown_sender,
|
||||
|
9
load_tests/Makefile
Normal file
9
load_tests/Makefile
Normal file
@ -0,0 +1,9 @@
|
||||
|
||||
ShareGPT_V3_unfiltered_cleaned_split.json:
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
prepare_share: ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
python filter.py
|
||||
|
||||
prepare_orca:
|
||||
python orca.py
|
@ -26,23 +26,23 @@ export function get_options() {
|
||||
// }],
|
||||
},
|
||||
scenarios: {
|
||||
single_user: {
|
||||
// single_user: {
|
||||
// executor: 'constant-arrival-rate',
|
||||
// duration: '60s',
|
||||
// preAllocatedVUs: 1,
|
||||
// rate: 20,
|
||||
// timeUnit: '1s',
|
||||
// },
|
||||
load_test: {
|
||||
executor: 'constant-arrival-rate',
|
||||
duration: '60s',
|
||||
preAllocatedVUs: 1,
|
||||
preAllocatedVUs: 100,
|
||||
rate: 1,
|
||||
timeUnit: '1s',
|
||||
},
|
||||
// load_test: {
|
||||
// executor: 'constant-arrival-rate',
|
||||
// duration: '60s',
|
||||
// preAllocatedVUs: 100,
|
||||
// rate: 1,
|
||||
// timeUnit: '1s',
|
||||
// },
|
||||
// breakpoint: {
|
||||
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
|
||||
// preAllocatedVUs: 1000,
|
||||
// preAllocatedVUs: 300,
|
||||
// stages: [
|
||||
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
|
||||
// ],
|
||||
|
26
load_tests/filter.py
Normal file
26
load_tests/filter.py
Normal file
@ -0,0 +1,26 @@
|
||||
import json
|
||||
|
||||
|
||||
def main():
|
||||
with open("./ShareGPT_V3_unfiltered_cleaned_split.json", "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Select only the first 2k conversations that start with a human.
|
||||
max = 2000
|
||||
conversations = []
|
||||
for conversation in data:
|
||||
conv = conversation.get("conversations")
|
||||
if conv and conv[0]["from"] == "human":
|
||||
# Trim the rest of the output
|
||||
conversation["conversations"] = conversation["conversations"][:1]
|
||||
conversations.append(conversation)
|
||||
|
||||
if len(conversation) >= max:
|
||||
break
|
||||
|
||||
with open("./small.json", "w") as f:
|
||||
data = json.dump(conversations, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
27
load_tests/orca.py
Normal file
27
load_tests/orca.py
Normal file
@ -0,0 +1,27 @@
|
||||
import json
|
||||
import datasets
|
||||
import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
dataset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
|
||||
# Select only the first 2k conversations that start with a human.
|
||||
max = min(2000, len(dataset))
|
||||
conversations = []
|
||||
for item in tqdm.tqdm(dataset, total=max):
|
||||
conversation = {
|
||||
"conversations": [
|
||||
{"from": "human", "value": item["question"]},
|
||||
],
|
||||
"id": item["id"],
|
||||
}
|
||||
conversations.append(conversation)
|
||||
if len(conversations) >= max:
|
||||
break
|
||||
|
||||
with open("./small.json", "w") as f:
|
||||
data = json.dump(conversations, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,63 +0,0 @@
|
||||
import {check} from 'k6';
|
||||
import http from 'k6/http';
|
||||
import {Trend} from 'k6/metrics';
|
||||
|
||||
const host = __ENV.HOST || '127.0.0.1:3000';
|
||||
|
||||
const totalTime = new Trend('total_time', true);
|
||||
const validationTime = new Trend('validation_time', true);
|
||||
const queueTime = new Trend('queue_time', true);
|
||||
const inferenceTime = new Trend('inference_time', true);
|
||||
const timePerToken = new Trend('time_per_token', true);
|
||||
|
||||
const example = {
|
||||
payload: JSON.stringify({
|
||||
inputs: '# This is a fibonacci function written in the Python programming language.' +
|
||||
'def fibonacci',
|
||||
parameters: {
|
||||
details: true,
|
||||
max_new_tokens: 60,
|
||||
temperature: 0.2,
|
||||
top_p: 0.95,
|
||||
seed: 0,
|
||||
},
|
||||
}),
|
||||
generated_tokens: 60
|
||||
};
|
||||
|
||||
export const options = {
|
||||
thresholds: {
|
||||
http_req_failed: ['rate==0'],
|
||||
time_per_token: ['p(95)<90'],
|
||||
queue_time: ['p(95)<1500'],
|
||||
},
|
||||
scenarios: {
|
||||
load_test: {
|
||||
executor: 'constant-arrival-rate',
|
||||
duration: '60s',
|
||||
preAllocatedVUs: 100,
|
||||
rate: 10,
|
||||
timeUnit: '1s',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export default function () {
|
||||
const headers = {'Content-Type': 'application/json'};
|
||||
const res = http.post(`http://${host}/generate`, example.payload, {
|
||||
headers,
|
||||
});
|
||||
|
||||
check(res, {
|
||||
'Post status is 200': (r) => res.status === 200,
|
||||
'Post response generated tokens': (r) => res.status === 200 && res.json().details.generated_tokens === example.generated_tokens,
|
||||
});
|
||||
|
||||
if (res.status === 200) {
|
||||
totalTime.add(res.headers["X-Total-Time"]);
|
||||
validationTime.add(res.headers["X-Validation-Time"]);
|
||||
queueTime.add(res.headers["X-Queue-Time"]);
|
||||
inferenceTime.add(res.headers["X-Inference-Time"]);
|
||||
timePerToken.add(res.headers["X-Time-Per-Token"]);
|
||||
}
|
||||
}
|
@ -884,12 +884,75 @@ pub(crate) struct ToolCall {
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
||||
pub(crate) struct Text {
|
||||
#[serde(default)]
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
||||
pub(crate) struct ImageUrl {
|
||||
#[serde(default)]
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
||||
pub(crate) struct Content {
|
||||
pub r#type: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub image_url: Option<ImageUrl>,
|
||||
}
|
||||
|
||||
mod message_content_serde {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
match value {
|
||||
Value::String(s) => Ok(Some(s)),
|
||||
Value::Array(arr) => {
|
||||
let results: Result<Vec<String>, _> = arr
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
let content: Content =
|
||||
serde_json::from_value(v).map_err(de::Error::custom)?;
|
||||
match content.r#type.as_str() {
|
||||
"text" => Ok(content.text.unwrap_or_default()),
|
||||
"image_url" => {
|
||||
if let Some(url) = content.image_url {
|
||||
Ok(format!("", url.url))
|
||||
} else {
|
||||
Ok(String::new())
|
||||
}
|
||||
}
|
||||
_ => Err(de::Error::custom("invalid content type")),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
results.map(|strings| Some(strings.join("")))
|
||||
}
|
||||
Value::Null => Ok(None),
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
|
||||
pub(crate) struct Message {
|
||||
#[schema(example = "user")]
|
||||
pub role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "My name is David and I")]
|
||||
#[serde(deserialize_with = "message_content_serde::deserialize")]
|
||||
pub content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "\"David\"")]
|
||||
|
@ -1,10 +1,10 @@
|
||||
vllm-cuda:
|
||||
# Clone vllm
|
||||
pip install -U ninja packaging --no-cache-dir
|
||||
git clone https://github.com/OlivierDehaene/vllm.git vllm
|
||||
git clone https://github.com/Narsil/vllm.git vllm
|
||||
|
||||
build-vllm-cuda: vllm-cuda
|
||||
cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
|
||||
cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
|
||||
cd vllm && python setup.py build
|
||||
|
||||
install-vllm-cuda: build-vllm-cuda
|
||||
|
922
server/poetry.lock
generated
922
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "text-generation-server"
|
||||
version = "2.0.1"
|
||||
version = "2.0.2"
|
||||
description = "Text Generation Inference Python gRPC Server"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
||||
@ -31,10 +31,12 @@ einops = "^0.6.1"
|
||||
texttable = { version = "^1.6.7", optional = true }
|
||||
datasets = { version = "^2.14.0", optional = true }
|
||||
peft = { version = "^0.10", optional = true }
|
||||
torch = { version = "^2.1.1", optional = true }
|
||||
torch = { version = "^2.3.0", optional = true }
|
||||
scipy = "^1.11.1"
|
||||
pillow = "^10.0.0"
|
||||
outlines= { version = "^0.0.36", optional = true }
|
||||
prometheus-client = "^0.20.0"
|
||||
py-cpuinfo = "^9.0.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
torch = ["torch"]
|
||||
|
@ -5,13 +5,13 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -28,9 +28,11 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -38,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -5,13 +5,13 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -28,9 +28,11 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -38,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -2,6 +2,7 @@ import math
|
||||
import torch
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
# Will be set in warmup
|
||||
@ -24,7 +25,10 @@ class CacheManager:
|
||||
self.repeat_slots = repeat_slots
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
x = self.block_size // element_size
|
||||
if IS_XPU_SYSTEM:
|
||||
x = 1
|
||||
else:
|
||||
x = self.block_size // element_size
|
||||
|
||||
self.kv_cache = [
|
||||
(
|
||||
|
@ -21,8 +21,10 @@ from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
if not IS_XPU_SYSTEM:
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
|
@ -24,7 +24,10 @@ import torch.distributed
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
if not IS_XPU_SYSTEM:
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
@ -34,6 +34,11 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
from text_generation_server.utils.import_utils import (
|
||||
IS_CUDA_SYSTEM,
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -753,7 +758,10 @@ class FlashCausalLM(Model):
|
||||
|
||||
def warmup(self, batch: FlashCausalLMBatch):
|
||||
# The warmup batch is the biggest batch we could ever receive
|
||||
torch.cuda.empty_cache()
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
torch.cuda.empty_cache()
|
||||
elif IS_XPU_SYSTEM:
|
||||
torch.xpu.empty_cache()
|
||||
try:
|
||||
cache_manager = set_cache_manager(
|
||||
batch.blocks,
|
||||
@ -779,7 +787,10 @@ class FlashCausalLM(Model):
|
||||
f"You need to decrease `--max-batch-prefill-tokens`"
|
||||
) from e
|
||||
|
||||
torch.cuda.synchronize(self.device)
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
torch.cuda.synchronize(self.device)
|
||||
elif IS_XPU_SYSTEM:
|
||||
torch.xpu.synchronize(self.device)
|
||||
|
||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||
# Calculate the number of blocks that can be allocated with the free memory
|
||||
@ -787,12 +798,20 @@ class FlashCausalLM(Model):
|
||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||
|
||||
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
|
||||
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
|
||||
total_gpu_memory = torch.cuda.get_device_properties(
|
||||
self.device
|
||||
).total_memory
|
||||
|
||||
free_memory = max(
|
||||
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
|
||||
)
|
||||
free_memory = max(
|
||||
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
|
||||
)
|
||||
elif IS_XPU_SYSTEM:
|
||||
total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
|
||||
free_memory = int(total_gpu_memory * 0.5)
|
||||
else:
|
||||
raise NotImplementedError("FlashModel is only available on GPU")
|
||||
|
||||
num_blocks = (
|
||||
# Leave 5% for some wiggle room
|
||||
|
@ -18,6 +18,8 @@ from text_generation_server.utils import (
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
def __init__(
|
||||
@ -33,6 +35,9 @@ class FlashLlama(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_XPU_SYSTEM:
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
|
@ -33,8 +33,9 @@ tracer = trace.get_tracer(__name__)
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle()
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
|
||||
|
||||
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
|
||||
@ -120,6 +121,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
if (
|
||||
tokenized_input[0] == tokenizer.bos_token_id
|
||||
and tokenized_input[1] == tokenizer.bos_token_id
|
||||
):
|
||||
tokenized_input = tokenized_input[1:]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
@ -316,6 +322,9 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_XPU_SYSTEM:
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||
|
||||
@ -511,33 +520,18 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
|
||||
if cu_seqlen_prefill is None:
|
||||
logits, speculative_logits = self.compiled_model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
else:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, speculative_logits
|
||||
|
@ -14,6 +14,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -32,6 +33,9 @@ class FlashNeoXSharded(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_XPU_SYSTEM:
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers.models.qwen2 import Qwen2Tokenizer
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models.cache_manager import BLOCK_SIZE
|
||||
@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from transformers.models.qwen2 import Qwen2Config
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral):
|
||||
else:
|
||||
raise NotImplementedError("FlashQwen2 is only available on GPU")
|
||||
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = Qwen2Config.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
@ -15,6 +15,7 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -33,6 +34,9 @@ class FlashRWSharded(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_XPU_SYSTEM:
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashRW is only available on GPU")
|
||||
|
||||
|
@ -18,6 +18,8 @@ from text_generation_server.utils import (
|
||||
Weights,
|
||||
)
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
@ -35,6 +37,9 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif IS_XPU_SYSTEM:
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import torch
|
||||
import os
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle()
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None and cuda_graphs != "0":
|
||||
if cuda_graphs is not None:
|
||||
try:
|
||||
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
|
||||
except Exception as e:
|
||||
|
@ -2,6 +2,7 @@ import asyncio
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import signal
|
||||
|
||||
from grpc import aio
|
||||
from loguru import logger
|
||||
@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||
|
||||
|
||||
class SignalHandler:
|
||||
KEEP_PROCESSING = True
|
||||
|
||||
def __init__(self):
|
||||
signal.signal(signal.SIGINT, self.exit_gracefully)
|
||||
signal.signal(signal.SIGTERM, self.exit_gracefully)
|
||||
|
||||
def exit_gracefully(self, signum, frame):
|
||||
print(f"Exiting gracefully: Signal {signum}")
|
||||
self.KEEP_PROCESSING = False
|
||||
|
||||
|
||||
signal_handler = SignalHandler()
|
||||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -57,7 +57,14 @@ def initialize_torch_distributed():
|
||||
options.is_high_priority_stream = True
|
||||
options._timeout = timedelta(seconds=60)
|
||||
else:
|
||||
backend = "gloo"
|
||||
try:
|
||||
import oneccl_bindings_for_pytorch
|
||||
|
||||
backend = "ccl"
|
||||
if os.getenv("CCL_WORKER_COUNT", None) is None:
|
||||
os.environ["CCL_WORKER_COUNT"] = str(1)
|
||||
except ImportError:
|
||||
backend = "gloo"
|
||||
options = None
|
||||
|
||||
if WORLD_SIZE == 1:
|
||||
|
@ -2,8 +2,13 @@ import os
|
||||
import torch
|
||||
|
||||
from loguru import logger
|
||||
import math
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
from text_generation_server.utils.import_utils import (
|
||||
IS_CUDA_SYSTEM,
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
)
|
||||
from text_generation_server.utils.flash_attn_triton import triton_attention
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
@ -113,6 +118,28 @@ def attention(
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
if window_size_left != -1:
|
||||
raise ValueError(
|
||||
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||
)
|
||||
return ipex.llm.functional.varlen_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
if IS_CUDA_SYSTEM and HAS_FLASH_ATTN_V2_CUDA:
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
|
@ -1,4 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
def is_xpu_available():
|
||||
try:
|
||||
import intel_extension_for_pytorch
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||
|
||||
|
||||
IS_ROCM_SYSTEM = torch.version.hip is not None
|
||||
IS_CUDA_SYSTEM = torch.version.cuda is not None
|
||||
IS_XPU_SYSTEM = is_xpu_available()
|
||||
|
@ -8,6 +8,8 @@ from typing import List, Tuple, Optional
|
||||
from loguru import logger
|
||||
from functools import lru_cache
|
||||
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
@ -18,7 +20,14 @@ except ImportError:
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
from text_generation_server.utils.import_utils import (
|
||||
IS_CUDA_SYSTEM,
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
)
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
HAS_AWQ = True
|
||||
try:
|
||||
@ -497,7 +506,7 @@ class MedusaModel(torch.nn.Module):
|
||||
self.heads = torch.nn.ModuleList(
|
||||
[
|
||||
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
|
||||
for i in range(medusa_config["medusa_num_heads"])
|
||||
for i in range(get_speculate())
|
||||
]
|
||||
)
|
||||
|
||||
@ -594,7 +603,7 @@ class MedusaHeadV2(nn.Module):
|
||||
)
|
||||
routing[k] = filename
|
||||
|
||||
self.n_medusa_heads = medusa_config["medusa_num_heads"]
|
||||
self.n_medusa_heads = get_speculate()
|
||||
|
||||
assert medusa_config["medusa_num_layers"] == 1
|
||||
self.linear = TensorParallelColumnLinear.load_multi(
|
||||
@ -872,7 +881,15 @@ try:
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
if IS_XPU_SYSTEM:
|
||||
res_out = hidden_states
|
||||
out = ipex.llm.functional.add_layer_norm(
|
||||
residual, hidden_states, self.weight, self.bias, self.eps, True
|
||||
)
|
||||
if residual is not None:
|
||||
res_out = residual
|
||||
return out, res_out
|
||||
elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
@ -918,7 +935,20 @@ try:
|
||||
return cls(weight, eps)
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if IS_XPU_SYSTEM:
|
||||
residual_out = hidden_states
|
||||
out = ipex.llm.functional.add_rms_norm(
|
||||
residual,
|
||||
hidden_states,
|
||||
self.weight,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
True,
|
||||
)
|
||||
if residual is not None:
|
||||
residual_out = residual
|
||||
return out, residual_out
|
||||
elif hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
@ -988,7 +1018,7 @@ try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
import rotary_emb
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm._C import ops
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
def _create_inv_freq(dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
@ -1044,6 +1074,10 @@ try:
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||
elif IS_XPU_SYSTEM:
|
||||
ipex.llm.functional.rotary_embedding(
|
||||
query, key, sin, cos, query.size(-1), True
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
@ -1163,6 +1197,7 @@ try:
|
||||
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
|
||||
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
|
@ -151,7 +151,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
|
||||
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
Frequency penalty as defined by OpenAI
|
||||
Frequency penalty as defined by OpenAI in
|
||||
https://platform.openai.com/docs/guides/text-generation/parameter-details
|
||||
|
||||
Args:
|
||||
frequency_penalty (`List[float]`):
|
||||
@ -165,15 +166,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
||||
score = -torch.where(
|
||||
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
||||
)
|
||||
# set score to 0 where input_ids is a padding token
|
||||
score *= input_ids.ne(0)
|
||||
batch_size, input_size = input_ids.size()
|
||||
vocab_size = scores.size(1)
|
||||
|
||||
return scores.scatter_add_(1, input_ids, score)
|
||||
# Calculate the frequency for each token so far
|
||||
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
|
||||
token_freq.scatter_add_(
|
||||
1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
|
||||
)
|
||||
token_freq /= input_size
|
||||
|
||||
# Apply the frequency penalty to logits
|
||||
scores -= token_freq * self.penalty_tensor
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.penalty = [self.penalty[i] for i in indices]
|
||||
|
@ -1,9 +1,13 @@
|
||||
import torch
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
from text_generation_server.utils.import_utils import (
|
||||
IS_CUDA_SYSTEM,
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
)
|
||||
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
# TODO: check is this is OK with XPU or we should guard it
|
||||
try:
|
||||
from vllm._C import cache_ops
|
||||
from vllm._C import ops
|
||||
@ -11,6 +15,9 @@ except Exception as e:
|
||||
raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
|
||||
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
@ -18,9 +25,14 @@ def reshape_and_cache(
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
if IS_XPU_SYSTEM:
|
||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots
|
||||
)
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
|
||||
|
||||
def attention(
|
||||
@ -55,6 +67,22 @@ def attention(
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
if IS_XPU_SYSTEM:
|
||||
query = query.contiguous()
|
||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
|
Loading…
Reference in New Issue
Block a user