mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00

Merge pull request #158 from kdamaszk/rebase-tgi-2-0-2

Rebase with TGI 2.0.2

This commit is contained in commit 3bf8e8e466.
.github/workflows/build.yaml (vendored): 93 changed lines

@@ -274,12 +274,105 @@ jobs:
          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min

  build-and-push-image-intel:
    concurrency:
      group: ${{ github.workflow }}-build-and-push-image-intel-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    needs:
      - start-runner
      - build-and-push-image # Wait for the main docker image to be built
      - integration-tests # Wait for the main integration-tests
    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
    permissions:
      contents: write
      packages: write
      # This is used to complete the identity challenge
      # with sigstore/fulcio when running outside of PRs.
      id-token: write
      security-events: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
      - name: Initialize Docker Buildx
        uses: docker/setup-buildx-action@v2.0.0
        with:
          install: true
      - name: Inject slug/short variables
        uses: rlespinasse/github-slug-action@v4.4.1
      - name: Tailscale
        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
        with:
          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
      - name: Login to GitHub Container Registry
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v2
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Login to internal Container Registry
        uses: docker/login-action@v2.1.0
        with:
          username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
          password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
          registry: registry.internal.huggingface.tech
      - name: Login to Azure Container Registry
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v2.1.0
        with:
          username: ${{ secrets.AZURE_DOCKER_USERNAME }}
          password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
          registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
      # If pull request
      - name: Extract metadata (tags, labels) for Docker
        if: ${{ github.event_name == 'pull_request' }}
        id: meta-pr
        uses: docker/metadata-action@v4.3.0
        with:
          images: |
            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
          tags: |
            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
      # If main, release or tag
      - name: Extract metadata (tags, labels) for Docker
        if: ${{ github.event_name != 'pull_request' }}
        id: meta
        uses: docker/metadata-action@v4.3.0
        with:
          flavor: |
            latest=false
          images: |
            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
            ghcr.io/huggingface/text-generation-inference
            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
          tags: |
            type=semver,pattern={{version}}-intel
            type=semver,pattern={{major}}.{{minor}}-intel
            type=raw,value=latest-intel,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
      - name: Build and push Docker image
        id: build-and-push
        uses: docker/build-push-action@v4
        with:
          context: .
          file: Dockerfile_intel
          push: true
          platforms: 'linux/amd64'
          build-args: |
            GIT_SHA=${{ env.GITHUB_SHA }}
            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-intel
          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min

  stop-runner:
    name: Stop self-hosted EC2 runner
    needs:
      - start-runner
      - build-and-push-image
      - build-and-push-image-rocm
      - build-and-push-image-intel
      - integration-tests
    runs-on: ubuntu-latest
    env:
.gitignore (vendored): 2 changed lines

@@ -11,3 +11,5 @@ server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

data/
@@ -9,7 +9,7 @@ members = [
resolver = "2"

[workspace.package]
-version = "2.0.1"
+version = "2.0.2"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
Dockerfile_intel (new file): 105 lines

@@ -0,0 +1,105 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release

# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base

USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

WORKDIR /usr/src
# Build pytorch and ipex
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b xpu_main origin/xpu-main
RUN git clone https://github.com/pytorch/pytorch.git && cd pytorch && git checkout 209f2fa8ff86652f67d75c2f19bf9cb9942fd018 && git apply /usr/src/intel-extension-for-pytorch/torch_patches/00*.patch

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_cuda.txt && \
    pip install ".[accelerate, peft, outlines]" --no-cache-dir

ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV DIAGUTIL_PATH=/opt/intel/oneapi/compiler/latest/etc/compiler/sys_check/sys_check.sh
ENV CCL_CONFIGURATION=cpu_gpu_dpcpp
ENV MANPATH=/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/compiler/latest/share/man
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CMPLR_ROOT=/opt/intel/oneapi/compiler/latest
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV OCL_ICD_FILENAMES=libintelocl_emu.so:libalteracl.so:/opt/intel/oneapi/compiler/latest/lib/libintelocl.so
ENV CLASSPATH=/opt/intel/oneapi/mpi/latest/share/java/mpi.jar:/opt/intel/oneapi/mpi/latest/share/java/mpi.jar
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
ENV MKLROOT=/opt/intel/oneapi/mkl/latest
ENV NLSPATH=/opt/intel/oneapi/mkl/latest/share/locale/%l_%t/%N:/opt/intel/oneapi/compiler/latest/lib/locale/%l_%t/%N
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
ENV CCL_ZE_IPC_EXCHANGE=sockets

RUN pip uninstall -y torch && cd pytorch && git submodule update --init --recursive && python setup.py install
RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=ON BUILD_WITH_CPU=ON USE_XETLA=ON python setup.py install

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# Final image
FROM base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
@@ -80,6 +80,7 @@ class Client:
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Tool]] = None,
        tool_prompt: Optional[str] = None,
        tool_choice: Optional[str] = None,
    ):
        """
@@ -119,6 +120,8 @@ class Client:
                higher are kept for generation
            tools (`List[Tool]`):
                List of tools to use
            tool_prompt (`str`):
                A prompt to be appended before the tools
            tool_choice (`str`):
                The tool to use
@@ -139,6 +142,7 @@ class Client:
            temperature=temperature,
            top_p=top_p,
            tools=tools,
            tool_prompt=tool_prompt,
            tool_choice=tool_choice,
        )
        if not stream:
@@ -466,6 +470,7 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Tool]] = None,
        tool_prompt: Optional[str] = None,
        tool_choice: Optional[str] = None,
    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
        """
@@ -505,6 +510,8 @@ class AsyncClient:
                higher are kept for generation
            tools (`List[Tool]`):
                List of tools to use
            tool_prompt (`str`):
                A prompt to be appended before the tools
            tool_choice (`str`):
                The tool to use
@@ -525,6 +532,7 @@ class AsyncClient:
            temperature=temperature,
            top_p=top_p,
            tools=tools,
            tool_prompt=tool_prompt,
            tool_choice=tool_choice,
        )
        if not stream:
@@ -159,6 +159,8 @@ class ChatRequest(BaseModel):
    top_p: Optional[float] = None
    # List of tools to be used
    tools: Optional[List[Tool]] = None
    # A prompt to be appended before the tools
    tool_prompt: Optional[str] = None
    # Choice of tool to be used
    tool_choice: Optional[str] = None
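For illustration, here is a minimal sketch of how a caller might use the new `tool_prompt` parameter; it assumes a TGI server on `localhost:3000` and a user-defined `tools` list like the weather tools shown in the guidance docs further below, and the prompt string itself is only an example.

```python
from text_generation import Client

# Sketch only: assumes a local TGI server and a `tools` list defined elsewhere
# (see the weather tools in the guidance documentation below).
client = Client("http://localhost:3000")

response = client.chat(
    messages=[{"role": "user", "content": "What is the weather like in New York?"}],
    tools=tools,
    tool_prompt="You have access to the following tools:",  # appended before the tools
    tool_choice="get_current_weather",
)
print(response.choices[0].message.tool_calls)
```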
@@ -25,6 +25,10 @@
    title: Non-core Model Serving
  - local: basic_tutorials/safety
    title: Safety
  - local: basic_tutorials/using_guidance
    title: Using Guidance, JSON, tools
  - local: basic_tutorials/visual_language_models
    title: Visual Language Models
  title: Tutorials
- sections:
  - local: conceptual/streaming
@@ -42,5 +46,6 @@
  - local: conceptual/speculation
    title: Speculation (Medusa, ngram)
  - local: conceptual/guidance
-   title: Guidance, JSON, tools (using outlines)
+   title: How Guidance Works (via outlines)
  title: Conceptual Guides
@@ -162,7 +162,7 @@ Options:
  This setting is only applied if there is room in the batch as defined by `max_batch_total_tokens`.

  [env: WAITING_SERVED_RATIO=]
- [default: 1.2]
+ [default: 0.3]

```
## MAX_BATCH_PREFILL_TOKENS
docs/source/basic_tutorials/using_guidance.md (new file): 419 lines

@@ -0,0 +1,419 @@
# Guidance

Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.

These features are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!

_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `/chat/completions` endpoint._

## How it works

TGI leverages the [outlines](https://github.com/outlines-dev/outlines) library to efficiently parse and compile the grammatical structures and tools specified by users. This integration transforms the defined grammars into an intermediate representation that acts as a framework to guide and constrain content generation, ensuring that outputs adhere to the specified grammatical rules.

If you are interested in the technical details on how outlines is used in TGI, you can check out the [conceptual guidance documentation](../conceptual/guidance).

## Table of Contents 📚

### Grammar and Constraints

- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.

### Tools and Functions

- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

## Grammar and Constraints 🛣️

### The Grammar Parameter

In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the LLM.

Using curl, you can make a request to TGI's `/generate` endpoint with the grammar parameter. This is the most primitive way to interact with the API, and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.

```json
curl localhost:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": {
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "activity": {
                        "type": "string"
                    },
                    "animals_seen": {
                        "type": "integer",
                        "minimum": 1,
                        "maximum": 5
                    },
                    "animals": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["location", "activity", "animals_seen", "animals"]
            }
        }
    }
}'
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}

```

A grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The LLM will then generate a response that conforms to the specified grammar.

> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
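To see that caching behaviour in practice, a rough sketch like the following can time a cold request (which compiles the grammar) against a warm one; it assumes a local TGI server on port 3000 and reuses a trimmed-down version of the payload from the curl example above.

```python
import time
import requests

# Assumes a TGI server on localhost:3000; `data` mirrors the grammar payload above,
# reduced to a single property to keep the sketch short.
data = {
    "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": {
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    },
}

for attempt in ("cold (compiles grammar)", "warm (cached grammar)"):
    start = time.time()
    requests.post("http://127.0.0.1:3000/generate", json=data, timeout=120)
    print(f"{attempt}: {time.time() - start:.2f}s")
```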
### Constrain with Pydantic

Using Pydantic models, we can define a grammar similar to the previous example in a shorter and more readable way.

```python
import requests
from pydantic import BaseModel, conint
from typing import List

class Animals(BaseModel):
    location: str
    activity: str
    animals_seen: conint(ge=1, le=5)  # Constrained integer type
    animals: List[str]

prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"

data = {
    "inputs": prompt,
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": Animals.schema()
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}

```

### JSON Schema Integration

If Pydantic's not your style, go raw with direct JSON Schema integration. This is similar to the first example but with programmatic control.

```python
import requests

json_schema = {
    "properties": {
        "location": {
            "type": "string"
        },
        "activity": {
            "type": "string"
        },
        "animals_seen": {
            "type": "integer",
            "minimum": 1,
            "maximum": 5
        },
        "animals": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    },
    "required": ["location", "activity", "animals_seen", "animals"]
}

data = {
    "inputs": "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
    "parameters": {
        "max_new_tokens": 200,
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": json_schema
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}

```

### Using the client

TGI provides a client library that makes it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.

```python
from text_generation import AsyncClient
from text_generation.types import GrammarType

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'generate' to complete
    response = await client.generate(
        "Whats Googles DNS",
        max_new_tokens=10,
        decoder_input_details=True,
        seed=1,
        grammar={
            "type": GrammarType.Regex,
            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
        },
    )

    # Once the response is received, you can process it
    print(response.generated_text)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# 118.8.0.84

```
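For scripts that do not need asyncio, a synchronous sketch along the same lines might look like this; it assumes the synchronous `Client` mirrors the `AsyncClient` call shown above and that the same local server is running.

```python
from text_generation import Client
from text_generation.types import GrammarType

# Synchronous sketch mirroring the async example above; assumes a TGI server on localhost:3000.
client = Client(base_url="http://localhost:3000")

response = client.generate(
    "Whats Googles DNS",
    max_new_tokens=10,
    seed=1,
    grammar={
        "type": GrammarType.Regex,
        "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
    },
)
print(response.generated_text)
```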
## Tools and Functions 🛠️

### The Tools Parameter

In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.

Tools are a set of user-defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammars, are defined as a JSON schema and can be passed as part of the parameters to the Messages API.

```json
curl localhost:3000/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": "What is the weather like in New York?"
        }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        }
                    },
                    "required": ["location", "format"]
                }
            }
        }
    ],
    "tool_choice": "get_current_weather"
}'
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
```

### Text Generation Inference Client

TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.

```python
from text_generation import AsyncClient

# NOTE: tools defined above and removed for brevity

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'chat' to complete
    response = await client.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    # Once the response is received, you can process it
    print(response.choices[0].message.tool_calls)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}

```

<details>
<summary>Tools used in example above</summary>

```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
            },
        },
    }
]
```

</details>

### OpenAI integration

TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

However, there are some minor differences in the API; for example, `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API, where `tool_choice="auto"` will choose a tool only if the model thinks it's necessary.

```python
from openai import OpenAI

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="_",
)

# NOTE: tools defined above and removed for brevity

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="auto",  # tool selected by model
    max_tokens=500,
)

called = chat_completion.choices[0].message.tool_calls
print(called)
# {
#     "id": 0,
#     "type": "function",
#     "function": {
#         "description": None,
#         "name": "tools",
#         "parameters": {
#             "format": "celsius",
#             "location": "San Francisco, CA",
#             "num_days": 3,
#         },
#     },
# }
```
docs/source/basic_tutorials/visual_language_models.md (new file): 170 lines

@@ -0,0 +1,170 @@
# Vision Language Model Inference in TGI

Visual Language Models (VLMs) are models that consume both image and text inputs to generate text.

VLMs are trained on a combination of image and text data and can handle a wide range of tasks, such as image captioning, visual question answering, and visual dialog.

> What distinguishes VLMs from other text and image models is their ability to handle long context and generate text that is coherent and relevant to the image even after multiple turns or, in some cases, multiple images.

Below are a couple of common use cases for vision language models:

- **Image Captioning**: Given an image, generate a caption that describes the image.
- **Visual Question Answering (VQA)**: Given an image and a question about the image, generate an answer to the question.
- **Multimodal Dialog**: Generate responses to multiple turns of images and conversations.
- **Image Information Retrieval**: Given an image, retrieve information from the image.

## How to Use a Vision Language Model?

### Hugging Face Hub Python Library

To infer with vision language models through Python, you can use the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The `InferenceClient` class provides a simple way to interact with the [Inference API](https://huggingface.co/docs/api-inference/index). Images can be passed as URLs or base64-encoded strings. The `InferenceClient` will automatically detect the image format.

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")
image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
prompt = f"![]({image})What is this a picture of?\n\n"
for token in client.text_generation(prompt, max_new_tokens=16, stream=True):
    print(token)

# This is a picture of an anthropomorphic rabbit in a space suit.
```

```python
from huggingface_hub import InferenceClient
import base64
import requests
import io

client = InferenceClient("http://127.0.0.1:3000")

# read image from local file
image_path = "rabbit.png"
with open(image_path, "rb") as f:
    image = base64.b64encode(f.read()).decode("utf-8")

image = f"data:image/png;base64,{image}"
prompt = f"![]({image})What is this a picture of?\n\n"

for token in client.text_generation(prompt, max_new_tokens=10, stream=True):
    print(token)

# This is a picture of an anthropomorphic rabbit in a space suit.
```

If you want additional details, you can add `details=True`. In this case, you get a `TextGenerationStreamResponse`, which contains additional information such as the probabilities and the tokens. For the final response in the stream, it also returns the full generated text.
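As a rough sketch of that option (assuming the same local endpoint as in the examples above; the exact response class name can vary between `huggingface_hub` versions):

```python
from huggingface_hub import InferenceClient

# Sketch only: assumes a TGI server on 127.0.0.1:3000, as in the examples above.
client = InferenceClient("http://127.0.0.1:3000")
prompt = "![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\n\n"

for chunk in client.text_generation(prompt, max_new_tokens=16, stream=True, details=True):
    # Each streamed chunk carries the token text and its log-probability.
    print(chunk.token.text, chunk.token.logprob)
    if chunk.generated_text is not None:
        # The final chunk also carries the full generated text.
        print(chunk.generated_text)
```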
### Inference Through Sending `cURL` Requests

To use the `generate_stream` endpoint with curl, you can add the `-N` flag. This flag disables curl's default buffering and shows data as it arrives from the server.

```bash
curl -N 127.0.0.1:3000/generate_stream \
    -X POST \
    -d '{"inputs":"What is this a picture of?\n\n","parameters":{"max_new_tokens":16, "seed": 42}}' \
    -H 'Content-Type: application/json'

# ...
# data:{"index":16,"token":{"id":28723,"text":".","logprob":-0.6196289,"special":false},"generated_text":"This is a picture of an anthropomorphic rabbit in a space suit.","details":null}
```

### Inference Through JavaScript

First, we need to install the `@huggingface/inference` library.

```bash
npm install @huggingface/inference
```

If you're using the free Inference API, you can use [Huggingface.js](https://huggingface.co/docs/huggingface.js/inference/README)'s `HfInference`. If you're using inference endpoints, you can use the `HfInferenceEndpoint` class to easily interact with the Inference API.

We can create a `HfInferenceEndpoint` providing our endpoint URL and a [Hugging Face access token](https://huggingface.co/settings/tokens).

```js
import { HfInferenceEndpoint } from "@huggingface/inference";

const hf = new HfInferenceEndpoint("http://127.0.0.1:3000", "HF_TOKEN");

const prompt =
  "What is this a picture of?\n\n";

const stream = hf.textGenerationStream({
  inputs: prompt,
  parameters: { max_new_tokens: 16, seed: 42 },
});
for await (const r of stream) {
  // yield the generated token
  process.stdout.write(r.token.text);
}

// This is a picture of an anthropomorphic rabbit in a space suit.
```

## Combining Vision Language Models with Other Features

VLMs in TGI have several advantages; for example, these models can be used in tandem with other features for more complex tasks. For example, you can use VLMs with [Guided Generation](/docs/conceptual/guided-generation) to generate specific JSON data from an image.

<div class="flex justify-center">
   <img
      src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
      width="400"
   />
</div>

For example, we can extract information from the rabbit image and generate a JSON object with the location, activity, number of animals seen, and the animals seen. That would look like this:

```json
{
  "activity": "Standing",
  "animals": ["Rabbit"],
  "animals_seen": 1,
  "location": "Rocky surface with mountains in the background and a red light on the rabbit's chest"
}
```

All we need to do is provide a JSON schema to the VLM model and it will generate the JSON object for us.

```bash
curl localhost:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "inputs":"What is this a picture of?\n\n",
    "parameters": {
        "max_new_tokens": 100,
        "seed": 42,
        "grammar": {
            "type": "json",
            "value": {
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "activity": {
                        "type": "string"
                    },
                    "animals_seen": {
                        "type": "integer",
                        "minimum": 1,
                        "maximum": 5
                    },
                    "animals": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["location", "activity", "animals_seen", "animals"]
            }
        }
    }
}'

# {
#   "generated_text": "{ \"activity\": \"Standing\", \"animals\": [ \"Rabbit\" ], \"animals_seen\": 1, \"location\": \"Rocky surface with mountains in the background and a red light on the rabbit's chest\" }"
# }
```

Want to learn more about how Vision Language Models work? Check out the [awesome blog post on the topic](https://huggingface.co/blog/vlms).
@@ -1,419 +1,86 @@
# Guidance

## What is Guidance?

Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure, uses a specific set of words, or produces output in a specific format.

## How is it used?

Guidance can be used in many ways, and the community is always finding new ways to use it.

Technically, guidance can be used to generate:

- a specific JSON object
- a function signature
- typed output like a list of integers

However, these use cases can span a wide range of applications, such as:

- extracting structured data from unstructured text
- summarizing text into a specific format
- limiting output to specific classes of words (acting as an LLM-powered classifier)
- generating the input to specific APIs or services
- providing reliable and consistent output for downstream tasks
- extracting data from multimodal inputs

## How it works?

Diving into the details, guidance is enabled by including a grammar with a generation request; the grammar is compiled and then used to modify the chosen tokens.

This process can be broken down into the following steps:

1. A request is sent to the backend, where it is processed and placed in a batch. Processing includes compiling the grammar into a finite state machine and a grammar state.

<div class="flex justify-center">
  <img
    class="block dark:hidden"
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/request-to-batch.gif"
  />
  <img
    class="hidden dark:block"
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/request-to-batch-dark.gif"
  />
</div>

2. The model does a forward pass over the batch. This returns probabilities for each token in the vocabulary for each request in the batch.

3. The process of choosing one of those tokens is called `sampling`. The model samples from the distribution of probabilities to choose the next token. In TGI, all of the steps before sampling are called `processor`. Grammars are applied as a processor that masks out tokens that are not allowed by the grammar.

<div class="flex justify-center">
  <img
    class="block dark:hidden"
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/logit-grammar-mask.gif"
  />
  <img
    class="hidden dark:block"
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/logit-grammar-mask-dark.gif"
  />
</div>

4. The grammar mask is applied and the model samples from the remaining tokens. Once a token is chosen, we update the grammar state with the new token, to prepare it for the next pass.

<div class="flex justify-center">
  <img
    class="block dark:hidden"
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/sample-logits.gif"
  />
  <img
    class="hidden dark:block"
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/sample-logits-dark.gif"
  />
</div>
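To make the masking step concrete, here is a minimal, illustrative sketch of a grammar-style logits mask in Python. It is not TGI's actual implementation; the `allowed_token_ids` callback stands in for whatever the compiled finite state machine would return for the current grammar state.

```python
import math
from typing import Callable, List, Sequence

def apply_grammar_mask(
    logits: List[float],
    grammar_state: int,
    allowed_token_ids: Callable[[int], Sequence[int]],
) -> List[float]:
    """Mask out every token the grammar does not allow in the current state.

    `allowed_token_ids` plays the role of the compiled finite state machine:
    given the current grammar state, it returns the token ids that keep the
    output inside the grammar. Disallowed tokens get -inf so they can never
    be sampled.
    """
    allowed = set(allowed_token_ids(grammar_state))
    return [score if token_id in allowed else -math.inf
            for token_id, score in enumerate(logits)]

# Toy usage: a 5-token vocabulary where only tokens 1 and 3 are legal next.
masked = apply_grammar_mask([0.2, 1.5, -0.3, 0.9, 0.1], grammar_state=0,
                            allowed_token_ids=lambda state: [1, 3])
print(masked)  # [-inf, 1.5, -inf, 0.9, -inf]
```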
## How to use Guidance?

There are two main ways to use guidance; you can either use the `/generate` endpoint with a grammar or use the `/chat/completion` endpoint with tools.

Under the hood, tools are a special case of grammars that allows the model to choose one or none of the provided tools.

Please refer to [using guidance](../basic_tutorials/using_guidance) for more examples and details on how to use guidance in Python, JavaScript, and cURL.

### Getting the most out of guidance

Depending on how you are using guidance, you may want to make use of different features. Here are some tips to get the most out of guidance:

- If you are using the `/generate` endpoint with a `grammar`, it is recommended to include the grammar in the prompt, prefixed by something like `Please use the following JSON schema to generate the output:`. This helps the model understand the context of the grammar and generate the output accordingly; a minimal sketch of this pattern follows the list.
- If you are getting a response with many repeated tokens, use the `frequency_penalty` or `repetition_penalty` parameters to reduce the number of repeated tokens in the output.
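As a sketch of the first tip (assuming a local TGI server and a small illustrative schema; the exact prompt wording is only an example), including the schema in the prompt alongside the `grammar` parameter could look like this:

```python
import json
import requests

# Illustrative only: assumes a TGI server on 127.0.0.1:3000.
schema = {
    "properties": {"location": {"type": "string"}, "activity": {"type": "string"}},
    "required": ["location", "activity"],
}

# Put the schema in the prompt so the model sees it, and pass it again as the grammar
# so the output is actually constrained to it.
prompt = (
    "Please use the following JSON schema to generate the output:\n"
    f"{json.dumps(schema)}\n"
    "I saw a puppy a cat and a raccoon during my bike ride in the park"
)

response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": prompt,
        "parameters": {
            "repetition_penalty": 1.3,
            "grammar": {"type": "json", "value": schema},
        },
    },
)
print(response.json()["generated_text"])
```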
"parameters": {
|
|
||||||
"repetition_penalty": 1.3,
|
|
||||||
"grammar": {
|
|
||||||
"type": "json",
|
|
||||||
"value": Animals.schema()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
'http://127.0.0.1:3000/generate',
|
|
||||||
headers=headers,
|
|
||||||
json=data
|
|
||||||
)
|
|
||||||
print(response.json())
|
|
||||||
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
### JSON Schema Integration
|
|
||||||
|
|
||||||
If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is simliar to the first example but with programmatic control.
|
|
||||||
|
|
||||||
```python
|
|
||||||
import requests
|
|
||||||
|
|
||||||
json_schema = {
|
|
||||||
"properties": {
|
|
||||||
"location": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"activity": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"animals_seen": {
|
|
||||||
"type": "integer",
|
|
||||||
"minimum": 1,
|
|
||||||
"maximum": 5
|
|
||||||
},
|
|
||||||
"animals": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["location", "activity", "animals_seen", "animals"]
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
|
|
||||||
"parameters": {
|
|
||||||
"max_new_tokens": 200,
|
|
||||||
"repetition_penalty": 1.3,
|
|
||||||
"grammar": {
|
|
||||||
"type": "json",
|
|
||||||
"value": json_schema
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
'http://127.0.0.1:3000/generate',
|
|
||||||
headers=headers,
|
|
||||||
json=data
|
|
||||||
)
|
|
||||||
print(response.json())
|
|
||||||
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
### Using the client
|
|
||||||
|
|
||||||
TGI provides a client library to that make it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from text_generation import AsyncClient
|
|
||||||
from text_generation.types import GrammarType
|
|
||||||
|
|
||||||
# NOTE: tools defined above and removed for brevity
|
|
||||||
|
|
||||||
# Define an async function to encapsulate the async operation
|
|
||||||
async def main():
|
|
||||||
client = AsyncClient(base_url="http://localhost:3000")
|
|
||||||
|
|
||||||
# Use 'await' to wait for the async method 'chat' to complete
|
|
||||||
response = await client.generate(
|
|
||||||
"Whats Googles DNS",
|
|
||||||
max_new_tokens=10,
|
|
||||||
decoder_input_details=True,
|
|
||||||
seed=1,
|
|
||||||
grammar={
|
|
||||||
"type": GrammarType.Regex,
|
|
||||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Once the response is received, you can process it
|
|
||||||
print(response.generated_text)
|
|
||||||
|
|
||||||
# Ensure the main async function is run in the event loop
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
||||||
# 118.8.0.84
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## Tools and Functions 🛠️
|
|
||||||
|
|
||||||
### The Tools Parameter
|
|
||||||
|
|
||||||
In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.
|
|
||||||
|
|
||||||
Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.
|
|
||||||
|
|
||||||
Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.
|
|
||||||
|
|
||||||
```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": "What is the weather like in New York?"
        }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        }
                    },
                    "required": ["location", "format"]
                }
            }
        }
    ],
    "tool_choice": "get_current_weather"
}'
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
```

<details>
<summary>Tools used in example below</summary>

```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
            },
        },
    },
]
```

</details>

### Text Generation Inference Client

TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.

```python
from text_generation import AsyncClient

# NOTE: tools defined above and removed for brevity

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'chat' to complete
    response = await client.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    # Once the response is received, you can process it
    print(response.choices[0].message.tool_calls)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.3-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}

```

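The paragraph above notes that the client ships in both synchronous and asynchronous flavours, but only the async version is shown. As a hedged sketch, the blocking variant would look roughly like this, assuming the synchronous `Client` exposes the same `chat` signature as `AsyncClient`.

```python
from text_generation import Client

# NOTE: tools defined above and removed for brevity

# Synchronous client pointed at the same local TGI server
client = Client(base_url="http://localhost:3000")

response = client.chat(
    max_tokens=100,
    seed=1,
    tools=tools,
    messages=[
        {
            "role": "system",
            "content": "You're a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Brooklyn, New York?",
        },
    ],
)

print(response.choices[0].message.tool_calls)
```
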
### OpenAI integration

TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

However, there are some minor differences in the API; for example, `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API, where `tool_choice="auto"` chooses a tool only if the model thinks one is necessary.

```python
from openai import OpenAI

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="_",
)

# NOTE: tools defined above and removed for brevity

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="auto",  # tool selected by model
    max_tokens=500,
)

called = chat_completion.choices[0].message.tool_calls
print(called)
# {
#     "id": 0,
#     "type": "function",
#     "function": {
#         "description": None,
#         "name": "tools",
#         "parameters": {
#             "format": "celsius",
#             "location": "San Francisco, CA",
#             "num_days": 3,
#         },
#     },
# }
```

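The tool call above still has to be executed on your side. As an illustrative, hedged sketch (not part of TGI or the OpenAI client), the snippet below maps the returned function name and arguments onto hypothetical local Python implementations of the tools defined earlier; the helper names and return strings are made up for the example.

```python
import json

# Hypothetical local implementations of the tools defined earlier in this guide.
def get_current_weather(location: str, format: str) -> str:
    return f"It is currently mild in {location} (unit: {format})."

def get_n_day_weather_forecast(location: str, format: str, num_days: int) -> str:
    return f"{num_days}-day forecast for {location}: mild throughout (unit: {format})."

TOOL_REGISTRY = {
    "get_current_weather": get_current_weather,
    "get_n_day_weather_forecast": get_n_day_weather_forecast,
}

def dispatch(name: str, arguments) -> str:
    """Look up the requested tool and call it with the model-provided arguments."""
    # Depending on the client version, the arguments may arrive as a dict (as in
    # the printed output above) or as a JSON-encoded string; handle both.
    if isinstance(arguments, str):
        arguments = json.loads(arguments)
    fn = TOOL_REGISTRY.get(name, get_n_day_weather_forecast)
    return fn(**arguments)

# Values taken from the printed tool call above.
print(dispatch("get_n_day_weather_forecast",
               {"format": "celsius", "location": "San Francisco, CA", "num_days": 3}))
```
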
@@ -1,5 +1,6 @@
 ## Speculation

 Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
 The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid.

@@ -36,7 +37,7 @@ In order to use medusa models in TGI, simply point to a medusa enabled model, an

 If you don't have a medusa model, or don't have the resource to fine-tune, you can try to use `n-gram`.
-Ngram works by trying to find in the previous sequence existing tokens that match, and use those as speculation.
+N-gram works by trying to find matching tokens in the previous sequence, and use those as speculation for generating new tokens. For example, if the tokens "np.mean" appear multiple times in the sequence, the model can speculate that the next continuation of the tokens "np." is probably also "mean".

 This is an extremely simple method, which works best for code, or highly repetitive text. This might not be beneficial, if the speculation misses too much.
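As an illustrative aside on the n-gram idea described in this hunk (not code from the commit), the lookup can be sketched in a few lines of Python: find the most recent earlier occurrence of the current suffix of token ids and propose the tokens that followed it as speculative candidates. All names and values below are hypothetical.

```python
from typing import List

def ngram_speculate(tokens: List[int], ngram: int = 2, num_spec: int = 4) -> List[int]:
    """Propose speculative tokens by matching the last `ngram` tokens
    against earlier occurrences in the sequence."""
    if len(tokens) < ngram + 1:
        return []
    suffix = tokens[-ngram:]
    # Search backwards for the most recent earlier occurrence of the suffix.
    for start in range(len(tokens) - ngram - 1, -1, -1):
        if tokens[start:start + ngram] == suffix:
            follow = tokens[start + ngram:start + ngram + num_spec]
            if follow:
                return follow
    return []

# The suffix [7, 3] also occurs at the start of the sequence, so the tokens
# that followed it there, [11, 5, 9], are proposed as speculation.
print(ngram_speculate([7, 3, 11, 5, 9, 2, 8, 7, 3], ngram=2, num_spec=3))
```
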
@@ -15,7 +15,7 @@ Token streaming is the mode in which the server returns the tokens one by one as
 />
 </div>

-With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality earlier than the end of the generation. This has different positive effects:
+With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality before the end of the generation. This has different positive effects:

 * Users can get results orders of magnitude earlier for extremely long queries.
 * Seeing something in progress allows users to stop the generation if it's not going in the direction they expect.
@@ -116,7 +116,7 @@ curl -N 127.0.0.1:8080/generate_stream \
 First, we need to install the `@huggingface/inference` library.
 `npm install @huggingface/inference`

-If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`. Let's
+If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`.

 We can create a `HfInferenceEndpoint` providing our endpoint URL and credential.
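To make the streaming behaviour discussed in this hunk concrete, here is a minimal, hedged sketch of consuming the `/generate_stream` server-sent-events endpoint with plain `requests`; the host, prompt, and field names follow the documented response shape, but treat it as illustrative rather than canonical.

```python
import json
import requests

# Stream tokens from a locally running TGI server (address is illustrative).
resp = requests.post(
    "http://127.0.0.1:8080/generate_stream",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
    stream=True,
)
for line in resp.iter_lines():
    # Each SSE event is a line of the form `data:{...json...}`.
    if line.startswith(b"data:"):
        event = json.loads(line[len(b"data:"):])
        print(event["token"]["text"], end="", flush=True)
```
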
@@ -18,8 +18,8 @@ Text Generation Inference implements many optimizations and features, such as:
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty)
 - Stop sequences
 - Log probabilities
-- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output.
 - Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance.
+- [Guidance](../conceptual/guidance): Enable function calling and tool-use by forcing the model to generate structured outputs based on your own predefined output schemas.

 Text Generation Inference is used in production by multiple projects, such as:
@@ -293,6 +293,7 @@ def launcher(event_loop):
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
+        max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)
@@ -334,6 +335,9 @@ def launcher(event_loop):
         if max_input_length:
             args.append("--max-input-length")
             args.append(str(max_input_length))
+        if max_batch_prefill_tokens:
+            args.append("--max-batch-prefill-tokens")
+            args.append(str(max_batch_prefill_tokens))
         if max_total_tokens:
             args.append("--max-total-tokens")
             args.append(str(max_total_tokens))
@@ -371,6 +375,7 @@ def launcher(event_loop):
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
+        max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)
@@ -395,6 +400,9 @@ def launcher(event_loop):
         if max_input_length:
             args.append("--max-input-length")
             args.append(str(max_input_length))
+        if max_batch_prefill_tokens:
+            args.append("--max-batch-prefill-tokens")
+            args.append(str(max_batch_prefill_tokens))
        if max_total_tokens:
             args.append("--max-total-tokens")
             args.append(str(max_total_tokens))
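For context on the hunks above, a test module would exercise the new `max_batch_prefill_tokens` argument through the `launcher` fixture roughly as sketched below; the model id and values are illustrative, not taken from the commit.

```python
import pytest

@pytest.fixture(scope="module")
def small_prefill_handle(launcher):
    # The new keyword is forwarded to the server as `--max-batch-prefill-tokens`.
    with launcher(
        "bigscience/bloom-560m",
        max_input_length=1024,
        max_batch_prefill_tokens=2048,
        max_total_tokens=2048,
    ) as handle:
        yield handle

@pytest.fixture(scope="module")
async def small_prefill(small_prefill_handle):
    await small_prefill_handle.health(300)
    return small_prefill_handle.client
```
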
@ -0,0 +1,89 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -8.5625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -10.78125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 288,
|
||||||
|
"logprob": -0.2854004,
|
||||||
|
"special": false,
|
||||||
|
"text": "ing"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.37573242,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 633,
|
||||||
|
"logprob": -0.09301758,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4480,
|
||||||
|
"logprob": -0.3322754,
|
||||||
|
"special": false,
|
||||||
|
"text": " feature"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 297,
|
||||||
|
"logprob": -0.8510742,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 272,
|
||||||
|
"logprob": -0.13464355,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2039,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " game"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28723,
|
||||||
|
"logprob": -0.89990234,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test requesting a new feature in the game.\n\n"
|
||||||
|
}
|
File diff suppressed because it is too large
@ -0,0 +1,73 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 330,
|
||||||
|
"logprob": -0.13000488,
|
||||||
|
"special": false,
|
||||||
|
"text": " A"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13088,
|
||||||
|
"logprob": -0.6713867,
|
||||||
|
"special": false,
|
||||||
|
"text": " chicken"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.2980957,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6398,
|
||||||
|
"logprob": -0.060638428,
|
||||||
|
"special": false,
|
||||||
|
"text": " sitting"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 356,
|
||||||
|
"logprob": -0.27319336,
|
||||||
|
"special": false,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 17972,
|
||||||
|
"logprob": -0.040405273,
|
||||||
|
"special": false,
|
||||||
|
"text": " pile"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 302,
|
||||||
|
"logprob": -0.0002708435,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2445,
|
||||||
|
"logprob": -0.095336914,
|
||||||
|
"special": false,
|
||||||
|
"text": " money"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28723,
|
||||||
|
"logprob": -0.0068359375,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " A chicken is sitting on a pile of money."
|
||||||
|
}
|
File diff suppressed because it is too large
81
integration-tests/models/test_idefics2.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import pytest
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
# TODO fix the server parsser to count inline image tokens correctly
|
||||||
|
def get_chicken():
|
||||||
|
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_idefics2_next_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"HuggingFaceM4/idefics2-8b",
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_idefics2_next(flash_idefics2_next_handle):
|
||||||
|
await flash_idefics2_next_handle.health(300)
|
||||||
|
return flash_idefics2_next_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
|
||||||
|
chicken = get_chicken()
|
||||||
|
response = await flash_idefics2_next.generate(
|
||||||
|
f"User:Write me a short story<end_of_utterance> \nAssistant:",
|
||||||
|
max_new_tokens=10,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
response.generated_text == " A chicken is sitting on a pile of money."
|
||||||
|
), f"{repr(response.generated_text)}"
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
|
||||||
|
response = await flash_idefics2_next.generate(
|
||||||
|
"Test request",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_idefics2_next_load(
|
||||||
|
flash_idefics2_next, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
chicken = get_chicken()
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_idefics2_next,
|
||||||
|
f"User:Write me a short story<end_of_utterance> \nAssistant:",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
generated_texts = [r.generated_text for r in responses]
|
||||||
|
assert generated_texts[0] == " A chicken is sitting on a pile of money."
|
||||||
|
assert len(generated_texts) == 4
|
||||||
|
assert all([r.generated_text == generated_texts[0] for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
@@ -7,14 +7,17 @@ pub(crate) struct Env {
     git_sha: &'static str,
     docker_label: &'static str,
     nvidia_env: String,
+    xpu_env: String,
 }

 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
+        let xpu_env = xpu_smi();

         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
+            xpu_env: xpu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -31,7 +34,8 @@ impl fmt::Display for Env {
         writeln!(f, "Cargo version: {}", self.cargo_version)?;
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
-        write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        write!(f, "xpu-smi:\n{}", self.xpu_env)?;

         Ok(())
     }
@@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
     let output = nvidia_smi.replace('\n', "\n ");
     Some(output.trim().to_string())
 }
+
+fn xpu_smi() -> Option<String> {
+    let output = Command::new("xpu-smi").arg("discovery").output().ok()?;
+    let xpu_smi = String::from_utf8(output.stdout).ok()?;
+    let output = xpu_smi.replace('\n', "\n ");
+    Some(output.trim().to_string())
+}
@ -253,7 +253,7 @@ struct Args {
|
|||||||
///
|
///
|
||||||
/// This setting is only applied if there is room in the batch
|
/// This setting is only applied if there is room in the batch
|
||||||
/// as defined by `max_batch_total_tokens`.
|
/// as defined by `max_batch_total_tokens`.
|
||||||
#[clap(default_value = "1.2", long, env)]
|
#[clap(default_value = "0.3", long, env)]
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
|
|
||||||
/// Limits the number of tokens for the prefill operation.
|
/// Limits the number of tokens for the prefill operation.
|
||||||
@ -435,7 +435,6 @@ fn shard_manager(
|
|||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
speculate: Option<usize>,
|
speculate: Option<usize>,
|
||||||
dtype: Option<Dtype>,
|
dtype: Option<Dtype>,
|
||||||
max_total_tokens: usize,
|
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
uds_path: String,
|
uds_path: String,
|
||||||
rank: usize,
|
rank: usize,
|
||||||
@ -451,6 +450,8 @@ fn shard_manager(
|
|||||||
cuda_memory_fraction: f32,
|
cuda_memory_fraction: f32,
|
||||||
rope_scaling: Option<RopeScaling>,
|
rope_scaling: Option<RopeScaling>,
|
||||||
rope_factor: Option<f32>,
|
rope_factor: Option<f32>,
|
||||||
|
max_total_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
status_sender: mpsc::Sender<ShardStatus>,
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
@ -515,6 +516,7 @@ fn shard_manager(
|
|||||||
(Some(scaling), Some(factor)) => Some((scaling, factor)),
|
(Some(scaling), Some(factor)) => Some((scaling, factor)),
|
||||||
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
|
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
|
||||||
};
|
};
|
||||||
|
|
||||||
// OpenTelemetry
|
// OpenTelemetry
|
||||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||||
shard_args.push("--otlp-endpoint".to_string());
|
shard_args.push("--otlp-endpoint".to_string());
|
||||||
@ -527,9 +529,6 @@ fn shard_manager(
|
|||||||
// Remove LOG_LEVEL if present
|
// Remove LOG_LEVEL if present
|
||||||
envs.retain(|(name, _)| name != "LOG_LEVEL");
|
envs.retain(|(name, _)| name != "LOG_LEVEL");
|
||||||
|
|
||||||
// Max total tokens
|
|
||||||
envs.push(("MAX_TOTAL_TOKENS".into(), max_total_tokens.to_string().into()));
|
|
||||||
|
|
||||||
// Torch Distributed Env vars
|
// Torch Distributed Env vars
|
||||||
if world_size == 1 {
|
if world_size == 1 {
|
||||||
envs.push(("RANK".into(), rank.to_string().into()));
|
envs.push(("RANK".into(), rank.to_string().into()));
|
||||||
@ -572,6 +571,14 @@ fn shard_manager(
|
|||||||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
envs.push((
|
||||||
|
"MAX_TOTAL_TOKENS".into(),
|
||||||
|
max_total_tokens.to_string().into(),
|
||||||
|
));
|
||||||
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
|
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
||||||
|
}
|
||||||
|
|
||||||
// If huggingface_hub_cache is some, pass it to the shard
|
// If huggingface_hub_cache is some, pass it to the shard
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||||
@ -680,8 +687,7 @@ fn shard_manager(
|
|||||||
|
|
||||||
// We received a shutdown signal
|
// We received a shutdown signal
|
||||||
if shutdown.load(Ordering::SeqCst) {
|
if shutdown.load(Ordering::SeqCst) {
|
||||||
terminate("Shard", p, Duration::from_secs(30)).unwrap();
|
terminate("shard", p, Duration::from_secs(90)).unwrap();
|
||||||
tracing::info!("Shard terminated");
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -975,13 +981,13 @@ fn spawn_shards(
|
|||||||
num_shard: usize,
|
num_shard: usize,
|
||||||
args: &Args,
|
args: &Args,
|
||||||
cuda_graphs: Vec<usize>,
|
cuda_graphs: Vec<usize>,
|
||||||
|
max_total_tokens: usize,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
shutdown_receiver: &mpsc::Receiver<()>,
|
shutdown_receiver: &mpsc::Receiver<()>,
|
||||||
shutdown_sender: mpsc::Sender<()>,
|
shutdown_sender: mpsc::Sender<()>,
|
||||||
status_receiver: &mpsc::Receiver<ShardStatus>,
|
status_receiver: &mpsc::Receiver<ShardStatus>,
|
||||||
status_sender: mpsc::Sender<ShardStatus>,
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
running: Arc<AtomicBool>,
|
running: Arc<AtomicBool>,
|
||||||
max_total_tokens: usize,
|
|
||||||
) -> Result<(), LauncherError> {
|
) -> Result<(), LauncherError> {
|
||||||
// Start shard processes
|
// Start shard processes
|
||||||
for rank in 0..1 {
|
for rank in 0..1 {
|
||||||
@ -1007,6 +1013,7 @@ fn spawn_shards(
|
|||||||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||||
let rope_scaling = args.rope_scaling;
|
let rope_scaling = args.rope_scaling;
|
||||||
let rope_factor = args.rope_factor;
|
let rope_factor = args.rope_factor;
|
||||||
|
let max_batch_size = args.max_batch_size;
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
shard_manager(
|
shard_manager(
|
||||||
model_id,
|
model_id,
|
||||||
@ -1014,7 +1021,6 @@ fn spawn_shards(
|
|||||||
quantize,
|
quantize,
|
||||||
speculate,
|
speculate,
|
||||||
dtype,
|
dtype,
|
||||||
max_total_tokens,
|
|
||||||
trust_remote_code,
|
trust_remote_code,
|
||||||
uds_path,
|
uds_path,
|
||||||
rank,
|
rank,
|
||||||
@ -1030,6 +1036,8 @@ fn spawn_shards(
|
|||||||
cuda_memory_fraction,
|
cuda_memory_fraction,
|
||||||
rope_scaling,
|
rope_scaling,
|
||||||
rope_factor,
|
rope_factor,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_size,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
status_sender,
|
status_sender,
|
||||||
shutdown,
|
shutdown,
|
||||||
@ -1240,7 +1248,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
|
|||||||
signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();
|
signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();
|
||||||
|
|
||||||
tracing::info!("Waiting for {process_name} to gracefully shutdown");
|
tracing::info!("Waiting for {process_name} to gracefully shutdown");
|
||||||
|
|
||||||
while terminate_time.elapsed() < timeout {
|
while terminate_time.elapsed() < timeout {
|
||||||
if let Some(status) = process.try_wait()? {
|
if let Some(status) = process.try_wait()? {
|
||||||
tracing::info!("{process_name} terminated");
|
tracing::info!("{process_name} terminated");
|
||||||
@ -1248,7 +1255,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
|
|||||||
}
|
}
|
||||||
sleep(Duration::from_millis(100));
|
sleep(Duration::from_millis(100));
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("Killing {process_name}");
|
tracing::info!("Killing {process_name}");
|
||||||
|
|
||||||
process.kill()?;
|
process.kill()?;
|
||||||
@ -1293,7 +1299,7 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
tracing::info!("{}", env_runtime);
|
tracing::info!("{}", env_runtime);
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!("{:?}", args);
|
tracing::info!("{:#?}", args);
|
||||||
|
|
||||||
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
||||||
let model_id = args.model_id.clone();
|
let model_id = args.model_id.clone();
|
||||||
@ -1326,7 +1332,12 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
|
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
|
||||||
if max_position_embeddings > max_default {
|
if max_position_embeddings > max_default {
|
||||||
let max = max_position_embeddings;
|
let max = max_position_embeddings;
|
||||||
|
if args.max_input_tokens.is_none()
|
||||||
|
&& args.max_total_tokens.is_none()
|
||||||
|
&& args.max_batch_prefill_tokens.is_none()
|
||||||
|
{
|
||||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||||
|
}
|
||||||
max_default
|
max_default
|
||||||
} else {
|
} else {
|
||||||
max_position_embeddings
|
max_position_embeddings
|
||||||
@ -1398,7 +1409,7 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
||||||
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
|
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
(
|
(
|
||||||
None,
|
None,
|
||||||
@ -1493,13 +1504,13 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
num_shard,
|
num_shard,
|
||||||
&args,
|
&args,
|
||||||
cuda_graphs,
|
cuda_graphs,
|
||||||
|
max_total_tokens,
|
||||||
shutdown.clone(),
|
shutdown.clone(),
|
||||||
&shutdown_receiver,
|
&shutdown_receiver,
|
||||||
shutdown_sender,
|
shutdown_sender,
|
||||||
&status_receiver,
|
&status_receiver,
|
||||||
status_sender,
|
status_sender,
|
||||||
running.clone(),
|
running.clone(),
|
||||||
max_total_tokens,
|
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// We might have received a termination signal
|
// We might have received a termination signal
|
||||||
|
@ -1,71 +1,94 @@
|
|||||||
import { check, randomSeed } from 'k6';
|
import { check } from 'k6';
|
||||||
|
import { scenario } from 'k6/execution';
|
||||||
import http from 'k6/http';
|
import http from 'k6/http';
|
||||||
import { Trend, Counter } from 'k6/metrics';
|
import { Trend, Counter } from 'k6/metrics';
|
||||||
import { randomItem } from 'https://jslib.k6.io/k6-utils/1.2.0/index.js';
|
|
||||||
|
|
||||||
const seed = 0;
|
const host = __ENV.HOST;
|
||||||
|
const model_id = __ENV.MODEL_ID;
|
||||||
const host = __ENV.HOST || '127.0.0.1:8000';
|
|
||||||
const timePerToken = new Trend('time_per_token', true);
|
const timePerToken = new Trend('time_per_token', true);
|
||||||
const tokens = new Counter('tokens');
|
const tokens = new Counter('tokens');
|
||||||
const new_tokens = new Counter('new_tokens');
|
const new_tokens = new Counter('new_tokens');
|
||||||
const input_tokens = new Counter('input_tokens');
|
const input_tokens = new Counter('input_tokens');
|
||||||
|
const max_new_tokens = 50;
|
||||||
|
|
||||||
randomSeed(seed);
|
|
||||||
// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
|
// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
|
||||||
const shareGPT = JSON.parse(open("small.json"))
|
const shareGPT = JSON.parse(open("small.json"))
|
||||||
|
|
||||||
|
|
||||||
export function get_options(reference_latency_ms){
|
export function get_options() {
|
||||||
return {
|
return {
|
||||||
thresholds: {
|
thresholds: {
|
||||||
http_req_failed: ['rate==0'],
|
http_req_failed: ['rate==0'],
|
||||||
time_per_token: [{
|
// time_per_token: [{
|
||||||
threshold: `p(50)<${5 * reference_latency_ms}`,
|
// threshold: `p(50)<${5 * reference_latency_ms}`,
|
||||||
abortOnFail: true,
|
// abortOnFail: true,
|
||||||
delayAbortEval: '10s'
|
// delayAbortEval: '10s'
|
||||||
}],
|
// }],
|
||||||
},
|
},
|
||||||
scenarios: {
|
scenarios: {
|
||||||
|
// single_user: {
|
||||||
|
// executor: 'constant-arrival-rate',
|
||||||
|
// duration: '60s',
|
||||||
|
// preAllocatedVUs: 1,
|
||||||
|
// rate: 20,
|
||||||
|
// timeUnit: '1s',
|
||||||
|
// },
|
||||||
load_test: {
|
load_test: {
|
||||||
executor: 'constant-arrival-rate',
|
executor: 'constant-arrival-rate',
|
||||||
duration: '60s',
|
duration: '60s',
|
||||||
preAllocatedVUs: 10,
|
preAllocatedVUs: 100,
|
||||||
rate: 10,
|
rate: 1,
|
||||||
timeUnit: '1s',
|
timeUnit: '1s',
|
||||||
},
|
},
|
||||||
|
// breakpoint: {
|
||||||
|
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
|
||||||
|
// preAllocatedVUs: 300,
|
||||||
|
// stages: [
|
||||||
|
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
|
||||||
|
// ],
|
||||||
|
// },
|
||||||
|
// throughput: {
|
||||||
|
// executor: 'shared-iterations',
|
||||||
|
// vus: 100,
|
||||||
|
// iterations: 200,
|
||||||
|
// maxDuration: '40s',
|
||||||
|
// },
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function generate_payload(gpt, max_new_tokens) {
|
||||||
|
const input = gpt["conversations"][0]["value"];
|
||||||
|
return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
|
||||||
|
}
|
||||||
|
|
||||||
export function run(host, generate_payload, max_new_tokens) {
|
export const options = get_options();
|
||||||
const headers = {'Content-Type': 'application/json'};
|
|
||||||
const query = randomItem(shareGPT);
|
export default function run() {
|
||||||
const payload = JSON.stringify(generate_payload(query));
|
const headers = { 'Content-Type': 'application/json' };
|
||||||
const res = http.post(`http://${host}/generate`, payload, {
|
const query = shareGPT[scenario.iterationInTest % shareGPT.length];
|
||||||
|
const payload = JSON.stringify(generate_payload(query, max_new_tokens));
|
||||||
|
const res = http.post(`http://${host}/v1/chat/completions`, payload, {
|
||||||
headers,
|
headers,
|
||||||
});
|
});
|
||||||
if(res.status >= 400 && res.status < 500){
|
if (res.status >= 400 && res.status < 500) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
check(res, {
|
check(res, {
|
||||||
'Post status is 200': (r) => res.status === 200,
|
'Post status is 200': (res) => res.status === 200,
|
||||||
});
|
});
|
||||||
const duration = res.timings.duration;
|
const duration = res.timings.duration;
|
||||||
|
|
||||||
if (res.status === 200) {
|
if (res.status === 200) {
|
||||||
const body = res.json();
|
const body = res.json();
|
||||||
const n_tokens = body.details.tokens.length;
|
const completion_tokens = body.usage.completion_tokens;
|
||||||
const latency_ms_per_token = duration / n_tokens;
|
const latency_ms_per_token = duration / completion_tokens;
|
||||||
timePerToken.add(latency_ms_per_token);
|
timePerToken.add(latency_ms_per_token);
|
||||||
const latency_in_s = latency_ms_per_token / 1000;
|
const prompt_tokens = body.usage.prompt_tokens;
|
||||||
const individual_throughput = 1 / latency_in_s;
|
input_tokens.add(prompt_tokens);
|
||||||
const _input_tokens = body.details.prefill.length;
|
new_tokens.add(completion_tokens);
|
||||||
tokens.add(n_tokens + _input_tokens);
|
tokens.add(completion_tokens + prompt_tokens);
|
||||||
input_tokens.add(_input_tokens);
|
|
||||||
new_tokens.add(n_tokens);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
import { get_options, run } from "./common.js";
|
|
||||||
|
|
||||||
const reference_latency_ms = 70;
|
|
||||||
const host = __ENV.HOST || '127.0.0.1:8000';
|
|
||||||
const max_new_tokens = 50;
|
|
||||||
|
|
||||||
|
|
||||||
function generate_payload(gpt){
|
|
||||||
const input = gpt["conversations"][0]["value"];
|
|
||||||
return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const options = get_options(reference_latency_ms);
|
|
||||||
|
|
||||||
export default function(){
|
|
||||||
run(host, generate_payload, max_new_tokens);
|
|
||||||
}
|
|
@ -1,17 +0,0 @@
|
|||||||
import { get_options, run } from "./common.js";
|
|
||||||
|
|
||||||
const reference_latency_ms = 22;
|
|
||||||
const host = __ENV.HOST || '127.0.0.1:8000';
|
|
||||||
const max_new_tokens = 50;
|
|
||||||
|
|
||||||
|
|
||||||
function generate_payload(gpt){
|
|
||||||
const input = gpt["conversations"][0]["value"];
|
|
||||||
return {"prompt": input, "temperature": 0.5, "ignore_eos": true}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const options = get_options(reference_latency_ms);
|
|
||||||
|
|
||||||
export default function(){
|
|
||||||
run(host, generate_payload, max_new_tokens);
|
|
||||||
}
|
|
@ -57,6 +57,31 @@ fn select_best_resolution(
|
|||||||
best_fit.unwrap_or((original_height, original_width))
|
best_fit.unwrap_or((original_height, original_width))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unpadded_features(
|
||||||
|
height: usize,
|
||||||
|
width: usize,
|
||||||
|
npatches: usize,
|
||||||
|
num_patch_height: usize,
|
||||||
|
num_patch_width: usize,
|
||||||
|
) -> (usize, usize) {
|
||||||
|
let current_height = npatches * num_patch_height;
|
||||||
|
let current_width = npatches * num_patch_width;
|
||||||
|
|
||||||
|
let aspect_ratio: f64 = width as f64 / height as f64;
|
||||||
|
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
|
||||||
|
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
|
||||||
|
let new_height = (height * current_width) / width;
|
||||||
|
(new_height, current_width)
|
||||||
|
} else {
|
||||||
|
let new_width = (width * current_height) / height;
|
||||||
|
(current_height, new_width)
|
||||||
|
};
|
||||||
|
|
||||||
|
let unpadded_features = current_height * current_width;
|
||||||
|
let newline_features = current_height;
|
||||||
|
(unpadded_features, newline_features)
|
||||||
|
}
|
||||||
|
|
||||||
impl LlavaNext {
|
impl LlavaNext {
|
||||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||||
let image_size = self.vision_config.image_size;
|
let image_size = self.vision_config.image_size;
|
||||||
@ -65,11 +90,9 @@ impl LlavaNext {
|
|||||||
let npatches = image_size / patch_size;
|
let npatches = image_size / patch_size;
|
||||||
let (num_patch_height, num_patch_width) =
|
let (num_patch_height, num_patch_width) =
|
||||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||||
// Ceil
|
|
||||||
let height_of_patch = (height * npatches + width - 1) / width;
|
let (unpadded_features, newline_features) =
|
||||||
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
|
get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
|
||||||
// They are only added after width
|
|
||||||
let newline_features = height_of_patch * num_patch_width;
|
|
||||||
// The base patch covers the entire image
|
// The base patch covers the entire image
|
||||||
let base_features = npatches.pow(2);
|
let base_features = npatches.pow(2);
|
||||||
unpadded_features + newline_features + base_features
|
unpadded_features + newline_features + base_features
|
||||||
@ -84,6 +107,17 @@ pub struct ClipVisionModel {
|
|||||||
patch_size: usize,
|
patch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "model_type")]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub struct Idefics2 {}
|
||||||
|
|
||||||
|
impl Idefics2 {
|
||||||
|
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
|
||||||
|
320
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(tag = "model_type")]
|
#[serde(tag = "model_type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
@ -92,6 +126,7 @@ pub enum Config {
|
|||||||
ClipVisionModel(ClipVisionModel),
|
ClipVisionModel(ClipVisionModel),
|
||||||
Mistral,
|
Mistral,
|
||||||
Idefics,
|
Idefics,
|
||||||
|
Idefics2(Idefics2),
|
||||||
Ssm,
|
Ssm,
|
||||||
GptBigcode,
|
GptBigcode,
|
||||||
Santacoder,
|
Santacoder,
|
||||||
@ -146,13 +181,17 @@ mod test {
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let slots = config.get_number_of_features(20, 20);
|
||||||
|
assert_eq!(slots, 1176);
|
||||||
let slots = config.get_number_of_features(640, 640);
|
let slots = config.get_number_of_features(640, 640);
|
||||||
assert_eq!(slots, 2928);
|
assert_eq!(slots, 2928);
|
||||||
let slots = config.get_number_of_features(480, 640);
|
let slots = config.get_number_of_features(480, 640);
|
||||||
assert_eq!(slots, 2340);
|
assert_eq!(slots, 2340);
|
||||||
let slots = config.get_number_of_features(899, 1024);
|
let slots = config.get_number_of_features(899, 1024);
|
||||||
assert_eq!(slots, 2732);
|
assert_eq!(slots, 2634);
|
||||||
let slots = config.get_number_of_features(1024, 899);
|
let slots = config.get_number_of_features(1024, 899);
|
||||||
assert_eq!(slots, 3320);
|
assert_eq!(slots, 2640);
|
||||||
|
let slots = config.get_number_of_features(1067, 1600);
|
||||||
|
assert_eq!(slots, 2144);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl HubTokenizerConfig {
|
impl HubTokenizerConfig {
|
||||||
pub fn from_file(filename: &std::path::Path) -> Self {
|
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||||
let content = std::fs::read_to_string(filename).unwrap();
|
let content = std::fs::read_to_string(filename).ok()?;
|
||||||
serde_json::from_str(&content).unwrap_or_default()
|
serde_json::from_str(&content).ok()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,6 +116,7 @@ mod token_serde {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Value::Null => Ok(None),
|
||||||
_ => Err(de::Error::custom("invalid token format")),
|
_ => Err(de::Error::custom("invalid token format")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -168,9 +169,12 @@ pub struct Info {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
||||||
pub(crate) struct GenerateParameters {
|
pub(crate) struct GenerateParameters {
|
||||||
|
/// Generate best_of sequences and return the one if the highest token logprobs.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
|
||||||
pub best_of: Option<usize>,
|
pub best_of: Option<usize>,
|
||||||
|
|
||||||
|
/// The value used to module the logits distribution.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = 0.0,
|
exclusive_minimum = 0.0,
|
||||||
@ -179,6 +183,9 @@ pub(crate) struct GenerateParameters {
|
|||||||
example = 0.5
|
example = 0.5
|
||||||
)]
|
)]
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// The parameter for repetition penalty. 1.0 means no penalty.
|
||||||
|
/// See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = 0.0,
|
exclusive_minimum = 0.0,
|
||||||
@ -187,6 +194,10 @@ pub(crate) struct GenerateParameters {
|
|||||||
example = 1.03
|
example = 1.03
|
||||||
)]
|
)]
|
||||||
pub repetition_penalty: Option<f32>,
|
pub repetition_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// The parameter for frequency penalty. 1.0 means no penalty
|
||||||
|
/// Penalize new tokens based on their existing frequency in the text so far,
|
||||||
|
/// decreasing the model's likelihood to repeat the same line verbatim.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = -2.0,
|
exclusive_minimum = -2.0,
|
||||||
@ -195,9 +206,13 @@ pub(crate) struct GenerateParameters {
|
|||||||
example = 0.1
|
example = 0.1
|
||||||
)]
|
)]
|
||||||
pub frequency_penalty: Option<f32>,
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
|
||||||
pub top_k: Option<i32>,
|
pub top_k: Option<i32>,
|
||||||
|
|
||||||
|
/// Top-p value for nucleus sampling.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = 0.0,
|
exclusive_minimum = 0.0,
|
||||||
@ -207,6 +222,9 @@ pub(crate) struct GenerateParameters {
|
|||||||
example = 0.95
|
example = 0.95
|
||||||
)]
|
)]
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Typical Decoding mass
|
||||||
|
/// See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = 0.0,
|
exclusive_minimum = 0.0,
|
||||||
@ -216,30 +234,48 @@ pub(crate) struct GenerateParameters {
|
|||||||
example = 0.95
|
example = 0.95
|
||||||
)]
|
)]
|
||||||
pub typical_p: Option<f32>,
|
pub typical_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Activate logits sampling.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(default = "false", example = true)]
|
#[schema(default = "false", example = true)]
|
||||||
pub do_sample: bool,
|
pub do_sample: bool,
|
||||||
|
|
||||||
|
/// Maximum number of tokens to generate.
|
||||||
#[serde(default = "default_max_new_tokens")]
|
#[serde(default = "default_max_new_tokens")]
|
||||||
#[schema(nullable = true, default = "100", example = "20")]
|
#[schema(nullable = true, default = "100", example = "20")]
|
||||||
pub max_new_tokens: Option<u32>,
|
pub max_new_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Whether to prepend the prompt to the generated text
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, default = "null", example = false)]
|
#[schema(nullable = true, default = "null", example = false)]
|
||||||
pub return_full_text: Option<bool>,
|
pub return_full_text: Option<bool>,
|
||||||
|
|
||||||
|
/// Stop generating tokens if a member of `stop` is generated.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
|
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
|
|
||||||
|
/// Truncate inputs tokens to the given size.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
pub truncate: Option<usize>,
|
pub truncate: Option<usize>,
|
||||||
|
|
||||||
|
/// Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(default = "false", example = true)]
|
#[schema(default = "false", example = true)]
|
||||||
pub watermark: bool,
|
pub watermark: bool,
|
||||||
|
|
||||||
|
/// Whether to return generation details.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(default = "true")]
|
#[schema(default = "true")]
|
||||||
pub details: bool,
|
pub details: bool,
|
||||||
|
|
||||||
|
/// Whether to return decoder input token logprobs and ids.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(default = "false")]
|
#[schema(default = "false")]
|
||||||
pub decoder_input_details: bool,
|
pub decoder_input_details: bool,
|
||||||
|
|
||||||
|
/// Random sampling seed.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = 0,
|
exclusive_minimum = 0,
|
||||||
@ -248,9 +284,13 @@ pub(crate) struct GenerateParameters {
|
|||||||
example = "null"
|
example = "null"
|
||||||
)]
|
)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
|
|
||||||
|
/// The number of highest probability vocabulary tokens to keep for top-n-filtering.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
||||||
pub top_n_tokens: Option<u32>,
|
pub top_n_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Grammar constraints for the generation.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
pub grammar: Option<GrammarType>,
|
pub grammar: Option<GrammarType>,
|
||||||
@ -549,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
|
|||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
pub(crate) struct ChatCompletionDelta {
|
pub(crate) struct ChatCompletionDelta {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
pub role: String,
|
// TODO Modify this to a true enum.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub role: Option<String>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "What is Deep Learning?")]
|
#[schema(example = "What is Deep Learning?")]
|
||||||
pub content: Option<String>,
|
pub content: Option<String>,
|
||||||
@ -583,6 +625,31 @@ impl ChatCompletionChunk {
|
|||||||
logprobs: Option<ChatCompletionLogprobs>,
|
logprobs: Option<ChatCompletionLogprobs>,
|
||||||
finish_reason: Option<String>,
|
finish_reason: Option<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let delta = match (delta, tool_calls) {
|
||||||
|
(Some(delta), _) => ChatCompletionDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: Some(delta),
|
||||||
|
tool_calls: None,
|
||||||
|
},
|
||||||
|
(None, Some(tool_calls)) => ChatCompletionDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: None,
|
||||||
|
tool_calls: Some(DeltaToolCall {
|
||||||
|
index: 0,
|
||||||
|
id: String::new(),
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: None,
|
||||||
|
arguments: tool_calls[0].to_string(),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
(None, None) => ChatCompletionDelta {
|
||||||
|
role: None,
|
||||||
|
content: None,
|
||||||
|
tool_calls: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
id: String::new(),
|
id: String::new(),
|
||||||
object: "text_completion".to_string(),
|
object: "text_completion".to_string(),
|
||||||
@ -591,19 +658,7 @@ impl ChatCompletionChunk {
|
|||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
choices: vec![ChatCompletionChoice {
|
choices: vec![ChatCompletionChoice {
|
||||||
index: 0,
|
index: 0,
|
||||||
delta: ChatCompletionDelta {
|
delta,
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: delta,
|
|
||||||
tool_calls: tool_calls.map(|tc| DeltaToolCall {
|
|
||||||
index: 0,
|
|
||||||
id: String::new(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: Function {
|
|
||||||
name: None,
|
|
||||||
arguments: tc[0].to_string(),
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
logprobs,
|
logprobs,
|
||||||
finish_reason,
|
finish_reason,
|
||||||
}],
|
}],
|
||||||
@@ -829,12 +884,75 @@ pub(crate) struct ToolCall {
     pub function: FunctionDefinition,
 }
 
-#[derive(Clone, Deserialize, ToSchema, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+pub(crate) struct Text {
+    #[serde(default)]
+    pub text: String,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+pub(crate) struct ImageUrl {
+    #[serde(default)]
+    pub url: String,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+pub(crate) struct Content {
+    pub r#type: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub text: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub image_url: Option<ImageUrl>,
+}
+
+mod message_content_serde {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+        match value {
+            Value::String(s) => Ok(Some(s)),
+            Value::Array(arr) => {
+                let results: Result<Vec<String>, _> = arr
+                    .into_iter()
+                    .map(|v| {
+                        let content: Content =
+                            serde_json::from_value(v).map_err(de::Error::custom)?;
+                        match content.r#type.as_str() {
+                            "text" => Ok(content.text.unwrap_or_default()),
+                            "image_url" => {
+                                if let Some(url) = content.image_url {
+                                    Ok(format!("![]({})", url.url))
+                                } else {
+                                    Ok(String::new())
+                                }
+                            }
+                            _ => Err(de::Error::custom("invalid content type")),
+                        }
+                    })
+                    .collect();
+
+                results.map(|strings| Some(strings.join("")))
+            }
+            Value::Null => Ok(None),
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
 pub(crate) struct Message {
     #[schema(example = "user")]
     pub role: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     #[schema(example = "My name is David and I")]
+    #[serde(deserialize_with = "message_content_serde::deserialize")]
     pub content: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
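
The deserializer above accepts either a bare string or an OpenAI-style array of typed parts. A minimal standalone sketch of that flattening, using only serde_json and a hypothetical helper name (flatten_content); image parts are rendered as markdown-style links here for illustration, mirroring the reconstruction above.

    use serde_json::{json, Value};

    // Flatten a string or an array of {type: "text"|"image_url", ...} parts into one Option<String>.
    fn flatten_content(value: &Value) -> Option<String> {
        match value {
            Value::String(s) => Some(s.clone()),
            Value::Array(parts) => Some(
                parts
                    .iter()
                    .filter_map(|p| match p["type"].as_str()? {
                        "text" => p["text"].as_str().map(str::to_owned),
                        "image_url" => p["image_url"]["url"]
                            .as_str()
                            .map(|u| format!("![]({u})")),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join(""),
            ),
            _ => None,
        }
    }

    fn main() {
        let plain = json!("What is Deep Learning?");
        let parts = json!([
            {"type": "text", "text": "Describe this image: "},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
        ]);
        assert_eq!(flatten_content(&plain).unwrap(), "What is Deep Learning?");
        println!("{}", flatten_content(&parts).unwrap());
    }
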
@@ -3,7 +3,7 @@
 use axum::http::HeaderValue;
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
-use hf_hub::{Repo, RepoType};
+use hf_hub::{Cache, Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
@@ -14,7 +14,7 @@ use std::env;
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::config::Config;
 use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
@@ -184,7 +184,6 @@ async fn main() -> Result<(), RouterError> {
     // Tokenizer instance
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
-    let local_model = local_path.exists() && local_path.is_dir();
 
     // Shared API builder initialization
     let api_builder = || {
@@ -203,53 +202,89 @@ async fn main() -> Result<(), RouterError> {
     let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
 
     // Initialize API if needed
+    #[derive(Clone)]
+    enum Type {
+        Api(Api),
+        Cache(Cache),
+        None,
+    }
     let api = if use_api {
+        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
+            let cache = Cache::default();
+            tracing::warn!("Offline mode active using cache defaults");
+            Type::Cache(cache)
+        } else {
             tracing::info!("Using the Hugging Face API");
             match api_builder().build() {
-                Ok(api) => Some(api),
+                Ok(api) => Type::Api(api),
                 Err(_) => {
                     tracing::warn!("Unable to build the Hugging Face API");
-                    None
+                    Type::None
+                }
             }
         }
     } else {
-        None
+        Type::None
     };
 
     // Load tokenizer and model info
     let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI")
         .ok()
        .map_or(false, |value| value.to_lowercase() == "true");
-    let (tokenizer, model_info, config) = if local_model {
-        let tokenizer = if skip_tokenizer_in_tgi {
-            None
-        } else {
-            Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
-        };
-        let model_info = HubModelInfo {
-            model_id: tokenizer_name.to_string(),
-            sha: None,
-            pipeline_tag: None,
-        };
-        let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
-            .ok()
-            .as_ref()
-            .and_then(|c| serde_json::from_str(c).ok());
-
-        (tokenizer, model_info, config)
-    } else if let Some(api) = api.clone() {
+    let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
+        Type::None => (
+            Some(local_path.join("tokenizer.json")),
+            Some(local_path.join("config.json")),
+            Some(local_path.join("tokenizer_config.json")),
+            None,
+        ),
+        Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(
                 tokenizer_name.to_string(),
                 RepoType::Model,
                 revision.clone().unwrap_or_else(|| "main".to_string()),
             ));
 
-            let tokenizer = match api_repo.get("tokenizer.json").await {
-                Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
+            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
+                Ok(tokenizer_filename) => Some(tokenizer_filename),
                 Err(_) => get_base_tokenizer(&api, &api_repo).await,
             };
+            let config_filename = api_repo.get("config.json").await.ok();
+            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
 
-            let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
+            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
+            (
+                tokenizer_filename,
+                config_filename,
+                tokenizer_config_filename,
+                model_info,
+            )
+        }
+        Type::Cache(cache) => {
+            let repo = cache.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main".to_string()),
+            ));
+            (
+                repo.get("tokenizer.json"),
+                repo.get("config.json"),
+                repo.get("tokenizer_config.json"),
+                None,
+            )
+        }
+    };
+    let tokenizer: Option<Tokenizer> = if skip_tokenizer_in_tgi {
+        None
+    } else {
+        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+    };
+    let config: Option<Config> = config_filename.and_then(|filename| {
         std::fs::read_to_string(filename)
             .ok()
             .as_ref()
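
A reduced sketch of the offline branch above, assuming the hf-hub cache API used in this diff (Cache::default, cache.repo, repo.get) and an example model id: files are resolved from the local cache only, and a miss simply yields None instead of a network call.

    use hf_hub::{Cache, Repo, RepoType};
    use std::path::PathBuf;

    // Resolve a repo file from the local Hugging Face cache only (no network),
    // mirroring the Type::Cache branch above.
    fn cached_file(model_id: &str, revision: &str, filename: &str) -> Option<PathBuf> {
        let cache = Cache::default();
        let repo = cache.repo(Repo::with_revision(
            model_id.to_string(),
            RepoType::Model,
            revision.to_string(),
        ));
        repo.get(filename)
    }

    fn main() {
        // Example model id; any repo already present in the local cache works.
        match cached_file("gpt2", "main", "tokenizer.json") {
            Some(path) => println!("tokenizer found at {}", path.display()),
            None => println!("tokenizer not in local cache"),
        }
    }
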
@@ -261,58 +296,25 @@ async fn main() -> Result<(), RouterError> {
             config.ok()
         })
     });
-    let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
-        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-        HubModelInfo {
+    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
         model_id: tokenizer_name.to_string(),
         sha: None,
         pipeline_tag: None,
-        }
     });
 
-        (tokenizer, model_info, config)
-    } else {
-        // No API and no local model
-        return Err(RouterError::ArgumentValidation(
-            "No local model found and no revision specified".to_string(),
-        ));
-    };
-
-    tracing::info!("Using config {config:?}");
-
-    // Load tokenizer config if found locally, or check if we can get it from the API if needed
-    let tokenizer_config = if let Some(path) = tokenizer_config_path {
-        tracing::info!("Using local tokenizer config from user specified path");
-        HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
-    } else if local_model {
-        tracing::info!("Using local tokenizer config");
-        HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
-    } else {
-        match api {
-            Some(api) => {
-                tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
-                let repo = Repo::with_revision(
-                    tokenizer_name.to_string(),
-                    RepoType::Model,
-                    revision.unwrap_or("main".to_string()),
-                );
-                get_tokenizer_config(&api.repo(repo))
-                    .await
-                    .unwrap_or_else(|| {
-                        tracing::warn!(
-                            "Could not retrieve tokenizer config from the Hugging Face hub."
-                        );
-                        HubTokenizerConfig::default()
-                    })
-            }
-            None => {
-                tracing::warn!("Could not find tokenizer config locally and no API specified");
-                HubTokenizerConfig::default()
-            }
-        }
-    };
+    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    {
+        HubTokenizerConfig::from_file(filename)
+    } else {
+        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });
+
+    tracing::info!("Using config {config:?}");
 
     if tokenizer.is_none() {
         tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
         tracing::warn!("Rust input length validation and truncation is disabled");
@@ -509,7 +511,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
 }
 
 /// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
+pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
     let config_filename = api_repo.get("config.json").await.ok()?;
 
     // Open the file in read-only mode with buffer.
@@ -526,8 +528,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
             "main".to_string(),
         ));
 
-        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
-        Tokenizer::from_file(tokenizer_filename).ok()
+        api_base_repo.get("tokenizer.json").await.ok()
     } else {
        None
     }
@@ -1002,6 +1002,7 @@ async fn chat_completions(
         tools,
         tool_choice,
         tool_prompt,
+        temperature,
         ..
     } = req;
 
@@ -1010,6 +1011,11 @@ async fn chat_completions(
     let logprobs = logprobs.unwrap_or(false);
     let tool_prompt = tool_prompt.unwrap_or_default();
     let stop = stop.unwrap_or_default();
+    // enable greedy only when temperature is 0
+    let (do_sample, temperature) = match temperature {
+        Some(temperature) if temperature == 0.0 => (false, None),
+        other => (true, other),
+    };
 
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
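
The temperature handling above, isolated into a plain helper for illustration: exactly 0.0 turns sampling off and clears the value so the backend decodes greedily; any other value keeps sampling on.

    // temperature == 0.0 means greedy decoding: disable sampling and clear the value.
    fn sampling_params(temperature: Option<f32>) -> (bool, Option<f32>) {
        match temperature {
            Some(t) if t == 0.0 => (false, None),
            other => (true, other),
        }
    }

    fn main() {
        assert_eq!(sampling_params(Some(0.0)), (false, None));
        assert_eq!(sampling_params(Some(0.7)), (true, Some(0.7)));
        assert_eq!(sampling_params(None), (true, None));
    }
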
@@ -1056,13 +1062,13 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: req.temperature,
+            temperature,
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
-            do_sample: true,
+            do_sample,
             max_new_tokens,
             return_full_text: None,
             stop,
@@ -1099,7 +1105,13 @@ async fn chat_completions(
                     let (content, tool_calls) = if tool_grammar.is_some() {
                         (None, Some(vec![stream_token.token.text]))
                     } else {
-                        (Some(stream_token.token.text), None)
+                        let content = if !stream_token.token.special {
+                            Some(stream_token.token.text)
+                        } else {
+                            None
+                        };
+
+                        (content, None)
                     };
 
                     event
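
The streamed-content filter above, reduced to a small illustrative helper (the token struct is a stand-in with only the two fields that matter here): special tokens such as end-of-sequence markers are dropped from the client-visible delta, ordinary tokens pass through.

    // Minimal stand-in for the streamed token.
    struct Token {
        text: String,
        special: bool,
    }

    // Special tokens never reach the chat delta.
    fn visible_content(token: Token) -> Option<String> {
        if !token.special {
            Some(token.text)
        } else {
            None
        }
    }

    fn main() {
        assert_eq!(
            visible_content(Token { text: "Hello".into(), special: false }),
            Some("Hello".to_string())
        );
        assert_eq!(visible_content(Token { text: "</s>".into(), special: true }), None);
    }
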
@@ -565,7 +565,57 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
-        Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
+        Some(Config::Idefics2(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str("<fake_token_around_image>");
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                tokenizer_query.push_str("<fake_token_around_image>");
+
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
+        Some(Config::Idefics) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = 1;
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
         _ => inputs.clone(),
     };
 
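
A standalone version of the placeholder-expansion loop above, using the regex crate; the pattern and the fixed slot count are illustrative stand-ins for the router's RE and config.get_number_of_features. Each matched image reference contributes repeated <image> tokens to the tokenizer query while the original reference stays in the model inputs (the tail is always appended here, a slight simplification of the loop above).

    use regex::Regex;

    // Expand every image reference into `slots` repeated "<image>" tokens for the
    // tokenizer query, keeping the original reference in the model inputs.
    fn expand_images(inputs: &str, slots: usize) -> (String, String) {
        // Illustrative pattern: markdown-style image references.
        let re = Regex::new(r"!\[\]\([^)]*\)").unwrap();
        let mut modified_inputs = String::with_capacity(inputs.len());
        let mut tokenizer_query = String::with_capacity(inputs.len());
        let mut start = 0;
        for chunk in re.find_iter(inputs) {
            if chunk.start() != start {
                modified_inputs.push_str(&inputs[start..chunk.start()]);
                tokenizer_query.push_str(&inputs[start..chunk.start()]);
            }
            tokenizer_query.push_str(&"<image>".repeat(slots));
            modified_inputs.push_str(chunk.as_str());
            start = chunk.end();
        }
        modified_inputs.push_str(&inputs[start..]);
        tokenizer_query.push_str(&inputs[start..]);
        (modified_inputs, tokenizer_query)
    }

    fn main() {
        let (inputs, query) = expand_images("Describe ![](https://example.com/cat.png) please", 3);
        println!("{inputs}");
        println!("{query}");
    }
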
@@ -1,10 +1,10 @@
 vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/OlivierDehaene/vllm.git vllm
+	git clone https://github.com/Narsil/vllm.git vllm
 
 build-vllm-cuda: vllm-cuda
-	cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
+	cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
 	cd vllm && python setup.py build
 
 install-vllm-cuda: build-vllm-cuda
441
server/poetry.lock
generated
@@ -194,13 +194,13 @@ files = [
 
 [[package]]
 name = "certifi"
-version = "2024.2.2"
+version = "2024.6.2"
 description = "Python package for providing Mozilla's CA Bundle."
 optional = false
 python-versions = ">=3.6"
 files = [
-    {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"},
-    {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"},
+    {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"},
+    {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"},
 ]
 
 [[package]]
@@ -357,13 +357,13 @@ cron = ["capturer (>=2.4)"]
 
 [[package]]
 name = "datasets"
-version = "2.19.1"
+version = "2.19.2"
 description = "HuggingFace community-driven open-source library of datasets"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "datasets-2.19.1-py3-none-any.whl", hash = "sha256:f7a78d15896f45004ccac1c298f3c7121f92f91f6f2bfbd4e4f210f827e6e411"},
-    {file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"},
+    {file = "datasets-2.19.2-py3-none-any.whl", hash = "sha256:e07ff15d75b1af75c87dd96323ba2a361128d495136652f37fd62f918d17bb4e"},
+    {file = "datasets-2.19.2.tar.gz", hash = "sha256:eccb82fb3bb5ee26ccc6d7a15b7f1f834e2cc4e59b7cff7733a003552bad51ef"},
 ]
 
 [package.dependencies]
@@ -379,7 +379,7 @@ pandas = "*"
 pyarrow = ">=12.0.0"
 pyarrow-hotfix = "*"
 pyyaml = ">=5.1"
-requests = ">=2.19.0"
+requests = ">=2.32.1"
 tqdm = ">=4.62.1"
 xxhash = "*"
 
@@ -387,7 +387,7 @@ xxhash = "*"
 apache-beam = ["apache-beam (>=2.26.0)"]
 audio = ["librosa", "soundfile (>=0.12.1)"]
 benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
-dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+dev = ["Pillow (>=9.4.0)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
 docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
 jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
 metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
@@ -395,9 +395,9 @@ quality = ["ruff (>=0.3.0)"]
 s3 = ["s3fs"]
 tensorflow = ["tensorflow (>=2.6.0)"]
 tensorflow-gpu = ["tensorflow (>=2.6.0)"]
-tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+tests = ["Pillow (>=9.4.0)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
 torch = ["torch"]
-vision = ["Pillow (>=6.2.1)"]
+vision = ["Pillow (>=9.4.0)"]
 
 [[package]]
 name = "deprecated"
@@ -628,17 +628,17 @@ tqdm = ["tqdm"]
 
 [[package]]
 name = "googleapis-common-protos"
-version = "1.63.0"
+version = "1.63.1"
 description = "Common protobufs used in Google APIs"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "googleapis-common-protos-1.63.0.tar.gz", hash = "sha256:17ad01b11d5f1d0171c06d3ba5c04c54474e883b66b949722b4938ee2694ef4e"},
-    {file = "googleapis_common_protos-1.63.0-py2.py3-none-any.whl", hash = "sha256:ae45f75702f7c08b541f750854a678bd8f534a1a6bace6afe975f1d0a82d6632"},
+    {file = "googleapis-common-protos-1.63.1.tar.gz", hash = "sha256:c6442f7a0a6b2a80369457d79e6672bb7dcbaab88e0848302497e3ec80780a6a"},
+    {file = "googleapis_common_protos-1.63.1-py2.py3-none-any.whl", hash = "sha256:0e1c2cdfcbc354b76e4a211a35ea35d6926a835cba1377073c4861db904a1877"},
 ]
 
 [package.dependencies]
-protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
 
 [package.extras]
 grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
@@ -662,61 +662,61 @@ testing = ["protobuf (>=4.21.9)"]
 
 [[package]]
 name = "grpcio"
-version = "1.64.0"
+version = "1.64.1"
 description = "HTTP/2-based RPC framework"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "grpcio-1.64.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3b09c3d9de95461214a11d82cc0e6a46a6f4e1f91834b50782f932895215e5db"},
+    {file = "grpcio-1.64.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:55697ecec192bc3f2f3cc13a295ab670f51de29884ca9ae6cd6247df55df2502"},
-    {file = "grpcio-1.64.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:7e013428ab472892830287dd082b7d129f4d8afef49227a28223a77337555eaa"},
+    {file = "grpcio-1.64.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:3b64ae304c175671efdaa7ec9ae2cc36996b681eb63ca39c464958396697daff"},
-    {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:02cc9cc3f816d30f7993d0d408043b4a7d6a02346d251694d8ab1f78cc723e7e"},
+    {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:bac71b4b28bc9af61efcdc7630b166440bbfbaa80940c9a697271b5e1dabbc61"},
-    {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f5de082d936e0208ce8db9095821361dfa97af8767a6607ae71425ac8ace15c"},
+    {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c024ffc22d6dc59000faf8ad781696d81e8e38f4078cb0f2630b4a3cf231a90"},
-    {file = "grpcio-1.64.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b7bf346391dffa182fba42506adf3a84f4a718a05e445b37824136047686a1"},
+    {file = "grpcio-1.64.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7cd5c1325f6808b8ae31657d281aadb2a51ac11ab081ae335f4f7fc44c1721d"},
-    {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b2cbdfba18408389a1371f8c2af1659119e1831e5ed24c240cae9e27b4abc38d"},
+    {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0a2813093ddb27418a4c99f9b1c223fab0b053157176a64cc9db0f4557b69bd9"},
-    {file = "grpcio-1.64.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:aca4f15427d2df592e0c8f3d38847e25135e4092d7f70f02452c0e90d6a02d6d"},
+    {file = "grpcio-1.64.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2981c7365a9353f9b5c864595c510c983251b1ab403e05b1ccc70a3d9541a73b"},
-    {file = "grpcio-1.64.0-cp310-cp310-win32.whl", hash = "sha256:7c1f5b2298244472bcda49b599be04579f26425af0fd80d3f2eb5fd8bc84d106"},
+    {file = "grpcio-1.64.1-cp310-cp310-win32.whl", hash = "sha256:1262402af5a511c245c3ae918167eca57342c72320dffae5d9b51840c4b2f86d"},
-    {file = "grpcio-1.64.0-cp310-cp310-win_amd64.whl", hash = "sha256:73f84f9e5985a532e47880b3924867de16fa1aa513fff9b26106220c253c70c5"},
+    {file = "grpcio-1.64.1-cp310-cp310-win_amd64.whl", hash = "sha256:19264fc964576ddb065368cae953f8d0514ecc6cb3da8903766d9fb9d4554c33"},
-    {file = "grpcio-1.64.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2a18090371d138a57714ee9bffd6c9c9cb2e02ce42c681aac093ae1e7189ed21"},
+    {file = "grpcio-1.64.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:58b1041e7c870bb30ee41d3090cbd6f0851f30ae4eb68228955d973d3efa2e61"},
-    {file = "grpcio-1.64.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:59c68df3a934a586c3473d15956d23a618b8f05b5e7a3a904d40300e9c69cbf0"},
+    {file = "grpcio-1.64.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bbc5b1d78a7822b0a84c6f8917faa986c1a744e65d762ef6d8be9d75677af2ca"},
-    {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b52e1ec7185512103dd47d41cf34ea78e7a7361ba460187ddd2416b480e0938c"},
+    {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5841dd1f284bd1b3d8a6eca3a7f062b06f1eec09b184397e1d1d43447e89a7ae"},
-    {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d598b5d5e2c9115d7fb7e2cb5508d14286af506a75950762aa1372d60e41851"},
+    {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8caee47e970b92b3dd948371230fcceb80d3f2277b3bf7fbd7c0564e7d39068e"},
-    {file = "grpcio-1.64.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01615bbcae6875eee8091e6b9414072f4e4b00d8b7e141f89635bdae7cf784e5"},
+    {file = "grpcio-1.64.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73819689c169417a4f978e562d24f2def2be75739c4bed1992435d007819da1b"},
-    {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0b2dfe6dcace264807d9123d483d4c43274e3f8c39f90ff51de538245d7a4145"},
+    {file = "grpcio-1.64.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6503b64c8b2dfad299749cad1b595c650c91e5b2c8a1b775380fcf8d2cbba1e9"},
-    {file = "grpcio-1.64.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7f17572dc9acd5e6dfd3014d10c0b533e9f79cd9517fc10b0225746f4c24b58e"},
+    {file = "grpcio-1.64.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1de403fc1305fd96cfa75e83be3dee8538f2413a6b1685b8452301c7ba33c294"},
-    {file = "grpcio-1.64.0-cp311-cp311-win32.whl", hash = "sha256:6ec5ed15b4ffe56e2c6bc76af45e6b591c9be0224b3fb090adfb205c9012367d"},
+    {file = "grpcio-1.64.1-cp311-cp311-win32.whl", hash = "sha256:d4d29cc612e1332237877dfa7fe687157973aab1d63bd0f84cf06692f04c0367"},
-    {file = "grpcio-1.64.0-cp311-cp311-win_amd64.whl", hash = "sha256:597191370951b477b7a1441e1aaa5cacebeb46a3b0bd240ec3bb2f28298c7553"},
+    {file = "grpcio-1.64.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e56462b05a6f860b72f0fa50dca06d5b26543a4e88d0396259a07dc30f4e5aa"},
-    {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"},
+    {file = "grpcio-1.64.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:4657d24c8063e6095f850b68f2d1ba3b39f2b287a38242dcabc166453e950c59"},
-    {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"},
+    {file = "grpcio-1.64.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:62b4e6eb7bf901719fce0ca83e3ed474ae5022bb3827b0a501e056458c51c0a1"},
-    {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"},
+    {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22"},
-    {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"},
+    {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:198908f9b22e2672a998870355e226a725aeab327ac4e6ff3a1399792ece4762"},
-    {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"},
+    {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b9d0acaa8d835a6566c640f48b50054f422d03e77e49716d4c4e8e279665a1"},
-    {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"},
+    {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5e42634a989c3aa6049f132266faf6b949ec2a6f7d302dbb5c15395b77d757eb"},
-    {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"},
+    {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b1a82e0b9b3022799c336e1fc0f6210adc019ae84efb7321d668129d28ee1efb"},
-    {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = "sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"},
+    {file = "grpcio-1.64.1-cp312-cp312-win32.whl", hash = "sha256:55260032b95c49bee69a423c2f5365baa9369d2f7d233e933564d8a47b893027"},
-    {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"},
+    {file = "grpcio-1.64.1-cp312-cp312-win_amd64.whl", hash = "sha256:c1a786ac592b47573a5bb7e35665c08064a5d77ab88a076eec11f8ae86b3e3f6"},
-    {file = "grpcio-1.64.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:c46fb6bfca17bfc49f011eb53416e61472fa96caa0979b4329176bdd38cbbf2a"},
+    {file = "grpcio-1.64.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:a011ac6c03cfe162ff2b727bcb530567826cec85eb8d4ad2bfb4bd023287a52d"},
-    {file = "grpcio-1.64.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3d2004e85cf5213995d09408501f82c8534700d2babeb81dfdba2a3bff0bb396"},
+    {file = "grpcio-1.64.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4d6dab6124225496010bd22690f2d9bd35c7cbb267b3f14e7a3eb05c911325d4"},
-    {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6d5541eb460d73a07418524fb64dcfe0adfbcd32e2dac0f8f90ce5b9dd6c046c"},
+    {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:a5e771d0252e871ce194d0fdcafd13971f1aae0ddacc5f25615030d5df55c3a2"},
-    {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f279ad72dd7d64412e10f2443f9f34872a938c67387863c4cd2fb837f53e7d2"},
+    {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c3c1b90ab93fed424e454e93c0ed0b9d552bdf1b0929712b094f5ecfe7a23ad"},
-    {file = "grpcio-1.64.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85fda90b81da25993aa47fae66cae747b921f8f6777550895fb62375b776a231"},
+    {file = "grpcio-1.64.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20405cb8b13fd779135df23fabadc53b86522d0f1cba8cca0e87968587f50650"},
-    {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a053584079b793a54bece4a7d1d1b5c0645bdbee729215cd433703dc2532f72b"},
+    {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0cc79c982ccb2feec8aad0e8fb0d168bcbca85bc77b080d0d3c5f2f15c24ea8f"},
-    {file = "grpcio-1.64.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:579dd9fb11bc73f0de061cab5f8b2def21480fd99eb3743ed041ad6a1913ee2f"},
+    {file = "grpcio-1.64.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a3a035c37ce7565b8f4f35ff683a4db34d24e53dc487e47438e434eb3f701b2a"},
-    {file = "grpcio-1.64.0-cp38-cp38-win32.whl", hash = "sha256:23b6887bb21d77649d022fa1859e05853fdc2e60682fd86c3db652a555a282e0"},
+    {file = "grpcio-1.64.1-cp38-cp38-win32.whl", hash = "sha256:1257b76748612aca0f89beec7fa0615727fd6f2a1ad580a9638816a4b2eb18fd"},
-    {file = "grpcio-1.64.0-cp38-cp38-win_amd64.whl", hash = "sha256:753cb58683ba0c545306f4e17dabf468d29cb6f6b11832e1e432160bb3f8403c"},
+    {file = "grpcio-1.64.1-cp38-cp38-win_amd64.whl", hash = "sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122"},
-    {file = "grpcio-1.64.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:2186d76a7e383e1466e0ea2b0febc343ffeae13928c63c6ec6826533c2d69590"},
+    {file = "grpcio-1.64.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:75dbbf415026d2862192fe1b28d71f209e2fd87079d98470db90bebe57b33179"},
-    {file = "grpcio-1.64.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0f30596cdcbed3c98024fb4f1d91745146385b3f9fd10c9f2270cbfe2ed7ed91"},
+    {file = "grpcio-1.64.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e3d9f8d1221baa0ced7ec7322a981e28deb23749c76eeeb3d33e18b72935ab62"},
-    {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:d9171f025a196f5bcfec7e8e7ffb7c3535f7d60aecd3503f9e250296c7cfc150"},
+    {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:5f8b75f64d5d324c565b263c67dbe4f0af595635bbdd93bb1a88189fc62ed2e5"},
-    {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf4c8daed18ae2be2f1fc7d613a76ee2a2e28fdf2412d5c128be23144d28283d"},
+    {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c84ad903d0d94311a2b7eea608da163dace97c5fe9412ea311e72c3684925602"},
-    {file = "grpcio-1.64.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3550493ac1d23198d46dc9c9b24b411cef613798dc31160c7138568ec26bc9b4"},
+    {file = "grpcio-1.64.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:940e3ec884520155f68a3b712d045e077d61c520a195d1a5932c531f11883489"},
-    {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3161a8f8bb38077a6470508c1a7301cd54301c53b8a34bb83e3c9764874ecabd"},
+    {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309"},
-    {file = "grpcio-1.64.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2e8fabe2cc57a369638ab1ad8e6043721014fdf9a13baa7c0e35995d3a4a7618"},
+    {file = "grpcio-1.64.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac15b6c2c80a4d1338b04d42a02d376a53395ddf0ec9ab157cbaf44191f3ffdd"},
-    {file = "grpcio-1.64.0-cp39-cp39-win32.whl", hash = "sha256:31890b24d47b62cc27da49a462efe3d02f3c120edb0e6c46dcc0025506acf004"},
+    {file = "grpcio-1.64.1-cp39-cp39-win32.whl", hash = "sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040"},
-    {file = "grpcio-1.64.0-cp39-cp39-win_amd64.whl", hash = "sha256:5a56797dea8c02e7d3a85dfea879f286175cf4d14fbd9ab3ef2477277b927baa"},
+    {file = "grpcio-1.64.1-cp39-cp39-win_amd64.whl", hash = "sha256:ed6091fa0adcc7e4ff944090cf203a52da35c37a130efa564ded02b7aff63bcd"},
-    {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"},
+    {file = "grpcio-1.64.1.tar.gz", hash = "sha256:8d51dd1c59d5fa0f34266b80a3805ec29a1f26425c2a54736133f6d87fc4968a"},
 ]
 
 [package.extras]
-protobuf = ["grpcio-tools (>=1.64.0)"]
+protobuf = ["grpcio-tools (>=1.64.1)"]
 
 [[package]]
 name = "grpcio-reflection"
@@ -883,13 +883,13 @@ files = [
 
 [[package]]
 name = "huggingface-hub"
-version = "0.23.2"
+version = "0.23.3"
 description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
-    {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
+    {file = "huggingface_hub-0.23.3-py3-none-any.whl", hash = "sha256:22222c41223f1b7c209ae5511d2d82907325a0e3cdbce5f66949d43c598ff3bc"},
+    {file = "huggingface_hub-0.23.3.tar.gz", hash = "sha256:1a1118a0b3dea3bab6c325d71be16f5ffe441d32f3ac7c348d6875911b694b5b"},
 ]
 
 [package.dependencies]
@@ -1756,13 +1756,13 @@ files = [
 
 [[package]]
 name = "optimum"
-version = "1.19.2"
+version = "1.20.0"
 description = "Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality."
 optional = false
 python-versions = ">=3.7.0"
 files = [
-    {file = "optimum-1.19.2-py3-none-any.whl", hash = "sha256:66f0fafda050ee6671bab6899852b9bf95afac766d99aa54a40699c7dee598bf"},
-    {file = "optimum-1.19.2.tar.gz", hash = "sha256:fc22e07f084d867bd9bce32fd0d737f7c4863514ea5d90c7acccf5dcfe5f2296"},
+    {file = "optimum-1.20.0-py3-none-any.whl", hash = "sha256:0c0d0746043c95e22cf3586946d7408d353f10c0486f1c7d2d11084a5cfc0ede"},
+    {file = "optimum-1.20.0.tar.gz", hash = "sha256:b64c7536fe738db9b56605105efe72006401ad2aa00cb499ae407f2e06f3043b"},
 ]
 
 [package.dependencies]
@@ -1773,7 +1773,7 @@ numpy = "*"
 packaging = "*"
 sympy = "*"
 torch = ">=1.11"
-transformers = {version = ">=4.26.0,<4.41.0", extras = ["sentencepiece"]}
+transformers = {version = ">=4.26.0,<4.42.0", extras = ["sentencepiece"]}
 
 [package.extras]
 amd = ["optimum-amd"]
@@ -1787,14 +1787,14 @@ exporters-tf = ["h5py", "numpy (<1.24.0)", "onnx", "onnxruntime", "tensorflow (>
 furiosa = ["optimum-furiosa"]
 graphcore = ["optimum-graphcore"]
 habana = ["optimum-habana", "transformers (>=4.38.0,<4.39.0)"]
-intel = ["optimum-intel (>=1.15.0)"]
-neural-compressor = ["optimum-intel[neural-compressor] (>=1.15.0)"]
-neuron = ["optimum-neuron[neuron] (>=0.0.20)", "transformers (==4.36.2)"]
-neuronx = ["optimum-neuron[neuronx] (>=0.0.20)", "transformers (==4.36.2)"]
-nncf = ["optimum-intel[nncf] (>=1.15.0)"]
+intel = ["optimum-intel (>=1.16.0)"]
+neural-compressor = ["optimum-intel[neural-compressor] (>=1.16.0)"]
+neuron = ["optimum-neuron[neuron] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"]
+neuronx = ["optimum-neuron[neuronx] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"]
+nncf = ["optimum-intel[nncf] (>=1.16.0)"]
 onnxruntime = ["datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime (>=1.11.0)", "protobuf (>=3.20.1)"]
 onnxruntime-gpu = ["accelerate", "datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime-gpu (>=1.11.0)", "protobuf (>=3.20.1)"]
-openvino = ["optimum-intel[openvino] (>=1.15.0)"]
+openvino = ["optimum-intel[openvino] (>=1.16.0)"]
 quality = ["black (>=23.1,<24.0)", "ruff (==0.1.5)"]
 tests = ["Pillow", "accelerate", "diffusers (>=0.17.0)", "einops", "invisible-watermark", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "sacremoses", "scikit-learn", "timm", "torchaudio", "torchvision"]
 
@@ -1856,13 +1856,13 @@ test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets"
 
 [[package]]
 name = "packaging"
-version = "24.0"
+version = "24.1"
 description = "Core utilities for Python packages"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
-    {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
+    {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
+    {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
 ]
 
 [[package]]
@@ -2068,6 +2068,20 @@ files = [
 dev = ["pre-commit", "tox"]
 testing = ["pytest", "pytest-benchmark"]
 
+[[package]]
+name = "prometheus-client"
+version = "0.20.0"
+description = "Python client for the Prometheus monitoring system."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"},
+    {file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"},
+]
+
+[package.extras]
+twisted = ["twisted"]
+
 [[package]]
 name = "protobuf"
 version = "3.20.3"
@@ -2127,6 +2141,17 @@ files = [
 [package.extras]
 test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
 
+[[package]]
+name = "py-cpuinfo"
+version = "9.0.0"
+description = "Get CPU info with pure Python"
+optional = false
+python-versions = "*"
+files = [
+    {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"},
+    {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"},
+]
+
 [[package]]
 name = "pyarrow"
 version = "16.1.0"
@@ -2188,18 +2213,18 @@ files = [
 
 [[package]]
 name = "pydantic"
-version = "2.7.1"
+version = "2.7.3"
 description = "Data validation using Python type hints"
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "pydantic-2.7.1-py3-none-any.whl", hash = "sha256:e029badca45266732a9a79898a15ae2e8b14840b1eabbb25844be28f0b33f3d5"},
-    {file = "pydantic-2.7.1.tar.gz", hash = "sha256:e9dbb5eada8abe4d9ae5f46b9939aead650cd2b68f249bb3a8139dbe125803cc"},
+    {file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"},
+    {file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"},
 ]
 
 [package.dependencies]
 annotated-types = ">=0.4.0"
-pydantic-core = "2.18.2"
+pydantic-core = "2.18.4"
 typing-extensions = ">=4.6.1"
 
 [package.extras]
@@ -2207,90 +2232,90 @@ email = ["email-validator (>=2.0.0)"]
 
 [[package]]
 name = "pydantic-core"
-version = "2.18.2"
+version = "2.18.4"
 description = "Core functionality for Pydantic validation and serialization"
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "pydantic_core-2.18.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9e08e867b306f525802df7cd16c44ff5ebbe747ff0ca6cf3fde7f36c05a59a81"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:f76d0ad001edd426b92233d45c746fd08f467d56100fd8f30e9ace4b005266e4"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f0a21cbaa69900cbe1a2e7cad2aa74ac3cf21b10c3efb0fa0b80305274c0e8a2"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:59ff3e89f4eaf14050c8022011862df275b552caef8082e37b542b066ce1ff26"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0680b1f1f11fda801397de52c36ce38ef1c1dc841a0927a94f226dea29c3ae3d"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a55b5b16c839df1070bc113c1f7f94a0af4433fcfa1b41799ce7606e5c79ce0a"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:95b9d5e72481d3780ba3442eac863eae92ae43a5f3adb5b4d0a1de89d42bb250"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4d0dcc59664fcb8974b356fe0a18a672d6d7cf9f54746c05f43275fc48636851"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fcf5cd9c4b655ad666ca332b9a081112cd7a58a8b5a6ca7a3104bc950f2038"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8951eee36c57cd128f779e641e21eb40bc5073eb28b2d23f33eb0ef14ffb3f5d"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b5155ff768083cb1d62f3e143b49a8a3432e6789a3abee8acd005c3c7af1c74"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4701b19f7e3a06ea655513f7938de6f108123bf7c86bbebb1196eb9bd35cf724"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:553ef617b6836fc7e4df130bb851e32fe357ce36336d897fd6646d6058d980af"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e00a3f196329e08e43d99b79b286d60ce46bed10f2280d25a1718399457e06be"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89ed9eb7d616ef5714e5590e6cf7f23b02d0d539767d33561e3675d6f9e3857"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:97736815b9cc893b2b7f663628e63f436018b75f44854c8027040e05230eeddb"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:75f7e9488238e920ab6204399ded280dc4c307d034f3924cd7f90a38b1829563"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6891a2ae0e8692679c07728819b6e2b822fb30ca7445f67bbf6509b25a96332c"},
-    {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ef26c9e94a8c04a1b2924149a9cb081836913818e55681722d7f29af88fe7b38"},
+    {file = "pydantic_core-2.18.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bc4ff9805858bd54d1a20efff925ccd89c9d2e7cf4986144b30802bf78091c3e"},
-    {file = "pydantic_core-2.18.2-cp310-none-win32.whl", hash = "sha256:182245ff6b0039e82b6bb585ed55a64d7c81c560715d1bad0cbad6dfa07b4027"},
+    {file = "pydantic_core-2.18.4-cp310-none-win32.whl", hash = "sha256:1b4de2e51bbcb61fdebd0ab86ef28062704f62c82bbf4addc4e37fa4b00b7cbc"},
-    {file = "pydantic_core-2.18.2-cp310-none-win_amd64.whl", hash = "sha256:e23ec367a948b6d812301afc1b13f8094ab7b2c280af66ef450efc357d2ae543"},
+    {file = "pydantic_core-2.18.4-cp310-none-win_amd64.whl", hash = "sha256:6a750aec7bf431517a9fd78cb93c97b9b0c496090fee84a47a0d23668976b4b0"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:219da3f096d50a157f33645a1cf31c0ad1fe829a92181dd1311022f986e5fbe3"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:942ba11e7dfb66dc70f9ae66b33452f51ac7bb90676da39a7345e99ffb55402d"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cc1cfd88a64e012b74e94cd00bbe0f9c6df57049c97f02bb07d39e9c852e19a4"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b2ebef0e0b4454320274f5e83a41844c63438fdc874ea40a8b5b4ecb7693f1c4"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b7133a6e6aeb8df37d6f413f7705a37ab4031597f64ab56384c94d98fa0e90"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a642295cd0c8df1b86fc3dced1d067874c353a188dc8e0f744626d49e9aa51c4"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:224c421235f6102e8737032483f43c1a8cfb1d2f45740c44166219599358c2cd"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f09baa656c904807e832cf9cce799c6460c450c4ad80803517032da0cd062e2"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b14d82cdb934e99dda6d9d60dc84a24379820176cc4a0d123f88df319ae9c150"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98906207f29bc2c459ff64fa007afd10a8c8ac080f7e4d5beff4c97086a3dabd"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2728b01246a3bba6de144f9e3115b532ee44bd6cf39795194fb75491824a1413"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19894b95aacfa98e7cb093cd7881a0c76f55731efad31073db4521e2b6ff5b7d"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:470b94480bb5ee929f5acba6995251ada5e059a5ef3e0dfc63cca287283ebfa6"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fbbdc827fe5e42e4d196c746b890b3d72876bdbf160b0eafe9f0334525119c8"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:997abc4df705d1295a42f95b4eec4950a37ad8ae46d913caeee117b6b198811c"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f85d05aa0918283cf29a30b547b4df2fbb56b45b135f9e35b6807cb28bc47951"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:75250dbc5290e3f1a0f4618db35e51a165186f9034eff158f3d490b3fed9f8a0"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e85637bc8fe81ddb73fda9e56bab24560bdddfa98aa64f87aaa4e4b6730c23d2"},
-    {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4456f2dca97c425231d7315737d45239b2b51a50dc2b6f0c2bb181fce6207664"},
+    {file = "pydantic_core-2.18.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2f5966897e5461f818e136b8451d0551a2e77259eb0f73a837027b47dc95dab9"},
-    {file = "pydantic_core-2.18.2-cp311-none-win32.whl", hash = "sha256:269322dcc3d8bdb69f054681edff86276b2ff972447863cf34c8b860f5188e2e"},
+    {file = "pydantic_core-2.18.4-cp311-none-win32.whl", hash = "sha256:44c7486a4228413c317952e9d89598bcdfb06399735e49e0f8df643e1ccd0558"},
-    {file = "pydantic_core-2.18.2-cp311-none-win_amd64.whl", hash = "sha256:800d60565aec896f25bc3cfa56d2277d52d5182af08162f7954f938c06dc4ee3"},
+    {file = "pydantic_core-2.18.4-cp311-none-win_amd64.whl", hash = "sha256:8a7164fe2005d03c64fd3b85649891cd4953a8de53107940bf272500ba8a788b"},
-    {file = "pydantic_core-2.18.2-cp311-none-win_arm64.whl", hash = "sha256:1404c69d6a676245199767ba4f633cce5f4ad4181f9d0ccb0577e1f66cf4c46d"},
+    {file = "pydantic_core-2.18.4-cp311-none-win_arm64.whl", hash = "sha256:4e99bc050fe65c450344421017f98298a97cefc18c53bb2f7b3531eb39bc7805"},
-    {file = "pydantic_core-2.18.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:fb2bd7be70c0fe4dfd32c951bc813d9fe6ebcbfdd15a07527796c8204bd36242"},
+    {file = "pydantic_core-2.18.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6f5c4d41b2771c730ea1c34e458e781b18cc668d194958e0112455fff4e402b2"},
-    {file = "pydantic_core-2.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6132dd3bd52838acddca05a72aafb6eab6536aa145e923bb50f45e78b7251043"},
+    {file = "pydantic_core-2.18.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2fdf2156aa3d017fddf8aea5adfba9f777db1d6022d392b682d2a8329e087cef"},
-    {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d904828195733c183d20a54230c0df0eb46ec746ea1a666730787353e87182"},
+    {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"},
{file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c9bd70772c720142be1020eac55f8143a34ec9f82d75a8e7a07852023e46617f"},
|
{file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:847a35c4d58721c5dc3dba599878ebbdfd96784f3fb8bb2c356e123bdcd73f34"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b8ed04b3582771764538f7ee7001b02e1170223cf9b75dff0bc698fadb00cf3"},
|
{file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c40d4eaad41f78e3bbda31b89edc46a3f3dc6e171bf0ecf097ff7a0ffff7cb1"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e6dac87ddb34aaec85f873d737e9d06a3555a1cc1a8e0c44b7f8d5daeb89d86f"},
|
{file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21a5e440dbe315ab9825fcd459b8814bb92b27c974cbc23c3e8baa2b76890077"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca4ae5a27ad7a4ee5170aebce1574b375de390bc01284f87b18d43a3984df72"},
|
{file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01dd777215e2aa86dfd664daed5957704b769e726626393438f9c87690ce78c3"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:886eec03591b7cf058467a70a87733b35f44707bd86cf64a615584fd72488b7c"},
|
{file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4b06beb3b3f1479d32befd1f3079cc47b34fa2da62457cdf6c963393340b56e9"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca7b0c1f1c983e064caa85f3792dd2fe3526b3505378874afa84baf662e12241"},
|
{file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:564d7922e4b13a16b98772441879fcdcbe82ff50daa622d681dd682175ea918c"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b4356d3538c3649337df4074e81b85f0616b79731fe22dd11b99499b2ebbdf3"},
|
{file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0eb2a4f660fcd8e2b1c90ad566db2b98d7f3f4717c64fe0a83e0adb39766d5b8"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-none-win32.whl", hash = "sha256:8b172601454f2d7701121bbec3425dd71efcb787a027edf49724c9cefc14c038"},
|
{file = "pydantic_core-2.18.4-cp312-none-win32.whl", hash = "sha256:8b8bab4c97248095ae0c4455b5a1cd1cdd96e4e4769306ab19dda135ea4cdb07"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-none-win_amd64.whl", hash = "sha256:b1bd7e47b1558ea872bd16c8502c414f9e90dcf12f1395129d7bb42a09a95438"},
|
{file = "pydantic_core-2.18.4-cp312-none-win_amd64.whl", hash = "sha256:14601cdb733d741b8958224030e2bfe21a4a881fb3dd6fbb21f071cabd48fa0a"},
|
||||||
{file = "pydantic_core-2.18.2-cp312-none-win_arm64.whl", hash = "sha256:98758d627ff397e752bc339272c14c98199c613f922d4a384ddc07526c86a2ec"},
|
{file = "pydantic_core-2.18.4-cp312-none-win_arm64.whl", hash = "sha256:c1322d7dd74713dcc157a2b7898a564ab091ca6c58302d5c7b4c07296e3fd00f"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9fdad8e35f278b2c3eb77cbdc5c0a49dada440657bf738d6905ce106dc1de439"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:823be1deb01793da05ecb0484d6c9e20baebb39bd42b5d72636ae9cf8350dbd2"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1d90c3265ae107f91a4f279f4d6f6f1d4907ac76c6868b27dc7fb33688cfb347"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ebef0dd9bf9b812bf75bda96743f2a6c5734a02092ae7f721c048d156d5fabae"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:390193c770399861d8df9670fb0d1874f330c79caaca4642332df7c682bf6b91"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae1d6df168efb88d7d522664693607b80b4080be6750c913eefb77e34c12c71a"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:82d5d4d78e4448683cb467897fe24e2b74bb7b973a541ea1dcfec1d3cbce39fb"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f9899c94762343f2cc2fc64c13e7cae4c3cc65cdfc87dd810a31654c9b7358cc"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4774f3184d2ef3e14e8693194f661dea5a4d6ca4e3dc8e39786d33a94865cefd"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99457f184ad90235cfe8461c4d70ab7dd2680e28821c29eca00252ba90308c78"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4d938ec0adf5167cb335acb25a4ee69a8107e4984f8fbd2e897021d9e4ca21b"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18f469a3d2a2fdafe99296a87e8a4c37748b5080a26b806a707f25a902c040a8"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0e8b1be28239fc64a88a8189d1df7fad8be8c1ae47fcc33e43d4be15f99cc70"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7cdf28938ac6b8b49ae5e92f2735056a7ba99c9b110a474473fd71185c1af5d"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:868649da93e5a3d5eacc2b5b3b9235c98ccdbfd443832f31e075f54419e1b96b"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:938cb21650855054dc54dfd9120a851c974f95450f00683399006aa6e8abb057"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:78363590ef93d5d226ba21a90a03ea89a20738ee5b7da83d771d283fd8a56761"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:44cd83ab6a51da80fb5adbd9560e26018e2ac7826f9626bc06ca3dc074cd198b"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:852e966fbd035a6468fc0a3496589b45e2208ec7ca95c26470a54daed82a0788"},
|
{file = "pydantic_core-2.18.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:972658f4a72d02b8abfa2581d92d59f59897d2e9f7e708fdabe922f9087773af"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-none-win32.whl", hash = "sha256:6a46e22a707e7ad4484ac9ee9f290f9d501df45954184e23fc29408dfad61350"},
|
{file = "pydantic_core-2.18.4-cp38-none-win32.whl", hash = "sha256:1d886dc848e60cb7666f771e406acae54ab279b9f1e4143babc9c2258213daa2"},
|
||||||
{file = "pydantic_core-2.18.2-cp38-none-win_amd64.whl", hash = "sha256:d91cb5ea8b11607cc757675051f61b3d93f15eca3cefb3e6c704a5d6e8440f4e"},
|
{file = "pydantic_core-2.18.4-cp38-none-win_amd64.whl", hash = "sha256:bb4462bd43c2460774914b8525f79b00f8f407c945d50881568f294c1d9b4443"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ae0a8a797a5e56c053610fa7be147993fe50960fa43609ff2a9552b0e07013e8"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:44a688331d4a4e2129140a8118479443bd6f1905231138971372fcde37e43528"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:042473b6280246b1dbf530559246f6842b56119c2926d1e52b631bdc46075f2a"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a2fdd81edd64342c85ac7cf2753ccae0b79bf2dfa063785503cb85a7d3593223"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a388a77e629b9ec814c1b1e6b3b595fe521d2cdc625fcca26fbc2d44c816804"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86110d7e1907ab36691f80b33eb2da87d780f4739ae773e5fc83fb272f88825f"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25add29b8f3b233ae90ccef2d902d0ae0432eb0d45370fe315d1a5cf231004b"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46387e38bd641b3ee5ce247563b60c5ca098da9c56c75c157a05eaa0933ed154"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f459a5ce8434614dfd39bbebf1041952ae01da6bed9855008cb33b875cb024c0"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:123c3cec203e3f5ac7b000bd82235f1a3eced8665b63d18be751f115588fea30"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eff2de745698eb46eeb51193a9f41d67d834d50e424aef27df2fcdee1b153845"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dc1803ac5c32ec324c5261c7209e8f8ce88e83254c4e1aebdc8b0a39f9ddb443"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8309f67285bdfe65c372ea3722b7a5642680f3dba538566340a9d36e920b5f0"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53db086f9f6ab2b4061958d9c276d1dbe3690e8dd727d6abf2321d6cce37fa94"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f93a8a2e3938ff656a7c1bc57193b1319960ac015b6e87d76c76bf14fe0244b4"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:abc267fa9837245cc28ea6929f19fa335f3dc330a35d2e45509b6566dc18be23"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:22057013c8c1e272eb8d0eebc796701167d8377441ec894a8fed1af64a0bf399"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a0d829524aaefdebccb869eed855e2d04c21d2d7479b6cada7ace5448416597b"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfeecd1ac6cc1fb2692c3d5110781c965aabd4ec5d32799773ca7b1456ac636b"},
|
{file = "pydantic_core-2.18.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:509daade3b8649f80d4e5ff21aa5673e4ebe58590b25fe42fac5f0f52c6f034a"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-none-win32.whl", hash = "sha256:0d69b4c2f6bb3e130dba60d34c0845ba31b69babdd3f78f7c0c8fae5021a253e"},
|
{file = "pydantic_core-2.18.4-cp39-none-win32.whl", hash = "sha256:ca26a1e73c48cfc54c4a76ff78df3727b9d9f4ccc8dbee4ae3f73306a591676d"},
|
||||||
{file = "pydantic_core-2.18.2-cp39-none-win_amd64.whl", hash = "sha256:d9319e499827271b09b4e411905b24a426b8fb69464dfa1696258f53a3334641"},
|
{file = "pydantic_core-2.18.4-cp39-none-win_amd64.whl", hash = "sha256:c67598100338d5d985db1b3d21f3619ef392e185e71b8d52bceacc4a7771ea7e"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a1874c6dd4113308bd0eb568418e6114b252afe44319ead2b4081e9b9521fe75"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:574d92eac874f7f4db0ca653514d823a0d22e2354359d0759e3f6a406db5d55d"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:ccdd111c03bfd3666bd2472b674c6899550e09e9f298954cfc896ab92b5b0e6d"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1f4d26ceb5eb9eed4af91bebeae4b06c3fb28966ca3a8fb765208cf6b51102ab"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e18609ceaa6eed63753037fc06ebb16041d17d28199ae5aba0052c51449650a9"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77450e6d20016ec41f43ca4a6c63e9fdde03f0ae3fe90e7c27bdbeaece8b1ed4"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e5c584d357c4e2baf0ff7baf44f4994be121e16a2c88918a5817331fc7599d7"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d323a01da91851a4f17bf592faf46149c9169d68430b3146dcba2bb5e5719abc"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43f0f463cf89ace478de71a318b1b4f05ebc456a9b9300d027b4b57c1a2064fb"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43d447dd2ae072a0065389092a231283f62d960030ecd27565672bd40746c507"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e1b395e58b10b73b07b7cf740d728dd4ff9365ac46c18751bf8b3d8cca8f625a"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:578e24f761f3b425834f297b9935e1ce2e30f51400964ce4801002435a1b41ef"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0098300eebb1c837271d3d1a2cd2911e7c11b396eac9661655ee524a7f10587b"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:81b5efb2f126454586d0f40c4d834010979cb80785173d1586df845a632e4e6d"},
|
||||||
{file = "pydantic_core-2.18.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:36789b70d613fbac0a25bb07ab3d9dba4d2e38af609c020cf4d888d165ee0bf3"},
|
{file = "pydantic_core-2.18.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ab86ce7c8f9bea87b9d12c7f0af71102acbf5ecbc66c17796cff45dae54ef9a5"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3f9a801e7c8f1ef8718da265bba008fa121243dfe37c1cea17840b0944dfd72c"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:90afc12421df2b1b4dcc975f814e21bc1754640d502a2fbcc6d41e77af5ec312"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3a6515ebc6e69d85502b4951d89131ca4e036078ea35533bb76327f8424531ce"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:51991a89639a912c17bef4b45c87bd83593aee0437d8102556af4885811d59f5"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20aca1e2298c56ececfd8ed159ae4dde2df0781988c97ef77d5c16ff4bd5b400"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:293afe532740370aba8c060882f7d26cfd00c94cae32fd2e212a3a6e3b7bc15e"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:223ee893d77a310a0391dca6df00f70bbc2f36a71a895cecd9a0e762dc37b349"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48ece5bde2e768197a2d0f6e925f9d7e3e826f0ad2271120f8144a9db18d5c8"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2334ce8c673ee93a1d6a65bd90327588387ba073c17e61bf19b4fd97d688d63c"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eae237477a873ab46e8dd748e515c72c0c804fb380fbe6c85533c7de51f23a8f"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cbca948f2d14b09d20268cda7b0367723d79063f26c4ffc523af9042cad95592"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:834b5230b5dfc0c1ec37b2fda433b271cbbc0e507560b5d1588e2cc1148cf1ce"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b3ef08e20ec49e02d5c6717a91bb5af9b20f1805583cb0adfe9ba2c6b505b5ae"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e858ac0a25074ba4bce653f9b5d0a85b7456eaddadc0ce82d3878c22489fa4ee"},
|
||||||
{file = "pydantic_core-2.18.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6fdc8627910eed0c01aed6a390a252fe3ea6d472ee70fdde56273f198938374"},
|
{file = "pydantic_core-2.18.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2fd41f6eff4c20778d717af1cc50eca52f5afe7805ee530a4fbd0bae284f16e9"},
|
||||||
{file = "pydantic_core-2.18.2.tar.gz", hash = "sha256:2e29d20810dfc3043ee13ac7d9e25105799817683348823f305ab3f349b9386e"},
|
{file = "pydantic_core-2.18.4.tar.gz", hash = "sha256:ec3beeada09ff865c344ff3bc2f427f5e6c26401cc6113d77e372c3fdac73864"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2519,13 +2544,13 @@ files = [

[[package]]
name = "requests"
-version = "2.32.2"
+version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
files = [
-{file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
+{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
-{file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
+{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]

[package.dependencies]
@@ -2892,17 +2917,17 @@ files = [

[[package]]
name = "sympy"
-version = "1.12"
+version = "1.12.1"
description = "Computer algebra system (CAS) in Python"
optional = false
python-versions = ">=3.8"
files = [
-{file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
+{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
-{file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
+{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
]

[package.dependencies]
-mpmath = ">=0.19"
+mpmath = ">=1.1.0,<1.4.0"

[[package]]
name = "tbb"
@@ -3057,31 +3082,31 @@ files = [

[[package]]
name = "torch"
-version = "2.3.0"
+version = "2.3.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.8.0"
files = [
-{file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"},
+{file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"},
-{file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"},
+{file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"},
-{file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"},
+{file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"},
-{file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"},
+{file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"},
-{file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"},
+{file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"},
-{file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"},
+{file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"},
-{file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"},
+{file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"},
-{file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"},
+{file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"},
-{file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"},
+{file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"},
-{file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"},
+{file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"},
-{file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"},
+{file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"},
-{file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"},
+{file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"},
-{file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"},
+{file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"},
-{file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"},
+{file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"},
-{file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"},
+{file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"},
-{file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"},
+{file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"},
-{file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"},
+{file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"},
-{file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"},
+{file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"},
-{file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"},
+{file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"},
-{file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"},
+{file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"},
]

[package.dependencies]
@@ -3102,7 +3127,7 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"
nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
sympy = "*"
-triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""}
+triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""}
typing-extensions = ">=4.8.0"

[package.extras]
@@ -3201,17 +3226,17 @@ vision = ["Pillow (>=10.0.1,<=15.0)"]

[[package]]
name = "triton"
-version = "2.3.0"
+version = "2.3.1"
description = "A language and compiler for custom Deep Learning operations"
optional = false
python-versions = "*"
files = [
-{file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"},
+{file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"},
-{file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"},
+{file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"},
-{file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"},
+{file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"},
-{file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"},
+{file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"},
-{file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"},
+{file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"},
-{file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"},
+{file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"},
]

[package.dependencies]
@@ -3244,13 +3269,13 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.

[[package]]
name = "typing-extensions"
-version = "4.12.0"
+version = "4.12.2"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
-{file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"},
+{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
-{file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"},
+{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
]

[[package]]
@@ -3596,20 +3621,20 @@ multidict = ">=4.0"

[[package]]
name = "zipp"
-version = "3.19.0"
+version = "3.19.2"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
-{file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
+{file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"},
-{file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
+{file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"},
]

[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
+test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]

[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13"
-content-hash = "f54475fbf8d56bf0ff0b26fbc12afb160648dd03aaefbd8ae636b920964d39f0"
+content-hash = "9c96638e5732ae8e97b76eec03e63ebb282bd24b2306e449a96a57f64dfdd53b"
@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
-version = "2.0.1"
+version = "2.0.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

@@ -26,6 +26,8 @@ optimum-habana = "1.11.1"
transformers = "4.38.2"
accelerate = "0.27.2"
outlines= { version = "^0.0.36", optional = true }
+prometheus-client = "^0.20.0"
+py-cpuinfo = "^9.0.0"

[tool.poetry.group.dev.dependencies]
grpcio-tools = "*"
@@ -4,12 +4,12 @@ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
attrs==23.2.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.6.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13"
-datasets==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
+datasets==2.19.2 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.26.3 ; python_version >= "3.9" and python_version < "3.13"
dill==0.3.8 ; python_version >= "3.9" and python_version < "3.13"
@@ -18,13 +18,13 @@ filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.1 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.64.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.23.2 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.3 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -50,13 +50,15 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.11.1 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.19.2 ; python_version >= "3.9" and python_version < "3.13"
-packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
pluggy==1.5.0 ; python_version >= "3.9" and python_version < "3.13"
+prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13"
psutil==5.9.8 ; python_version >= "3.9" and python_version < "3.13"
+py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==16.1.0 ; python_version >= "3.9" and python_version < "3.13"
pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13"
@@ -65,12 +67,12 @@ python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.1
pytz==2024.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
-sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
tbb==2021.12.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
@@ -78,11 +80,11 @@ tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.38.2 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.38.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
-zipp==3.19.0 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -5,13 +5,13 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -28,9 +28,11 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
+py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -38,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
@@ -5,13 +5,13 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -28,9 +28,11 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
+py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -38,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
@@ -2,6 +2,7 @@ import math
 import torch

 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -24,6 +25,9 @@ class CacheManager:
         self.repeat_slots = repeat_slots

         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = self.block_size // element_size
+        if IS_XPU_SYSTEM:
+            x = 1
+        else:
+            x = self.block_size // element_size

         self.kv_cache = [
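The hunk above switches the KV-cache key layout when running on an Intel XPU. Below is a minimal sketch of what the `x` factor does, assuming the vLLM-style cache shape used around this code; the dtype and head sizes are illustrative, not taken from the diff.

import torch

BLOCK_SIZE = 16
num_blocks, num_heads, head_size = 8, 4, 64
dtype = torch.float16

element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for fp16

# CUDA/ROCm: pack `x` consecutive key elements together so a 16-byte load
# stays inside one block (x = 16 // 2 = 8 for fp16).
x = BLOCK_SIZE // element_size
key_cache = torch.empty(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x, dtype=dtype)

# XPU: the hunk sets x = 1, which keeps the key cache in a plain layout.
x_xpu = 1
key_cache_xpu = torch.empty(num_blocks, num_heads, head_size // x_xpu, BLOCK_SIZE, x_xpu, dtype=dtype)

print(key_cache.shape, key_cache_xpu.shape)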
@@ -21,8 +21,10 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
-from vllm.model_executor.layers.fused_moe import fused_moe
+if not IS_XPU_SYSTEM:
+    from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     FastLinear,
@@ -38,58 +38,6 @@ from text_generation_server.utils.layers import (
 )


-class LlamaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        pretraining_tp=1,
-        tie_word_embeddings=False,
-        rope_scaling=None,
-        rope_theta=10000.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
@@ -101,6 +49,13 @@ def load_attention(config, prefix, weights):
            weights=weights,
            bias=False,
        )
+    elif config.model_type == "phi3":
+        return TensorParallelColumnLinear.load_qkv(
+            config,
+            prefix=f"{prefix}.qkv_proj",
+            weights=weights,
+            bias=False,
+        )
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
@@ -257,6 +212,14 @@ class LlamaMLP(nn.Module):
            )
        )
        # Fuse gate and up proj
+        if config.model_type == "phi3":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.gate_up_proj",
+                weights=weights,
+                bias=False,
+            )
+        else:
            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
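The phi3 branches added above load one fused projection (`qkv_proj`, `gate_up_proj`) instead of separate q/k/v or gate/up weights. A toy sketch of splitting such fused tensors, with made-up sizes; the real `load_qkv`/`load_gate_up` helpers additionally shard the weight across tensor-parallel ranks.

import torch

hidden_size, num_heads, head_size = 64, 4, 16  # illustrative only
qkv_proj = torch.randn(3 * num_heads * head_size, hidden_size)  # fused phi3-style weight

# Recover the three projections from the fused weight.
q_w, k_w, v_w = qkv_proj.split(num_heads * head_size, dim=0)

gate_up_proj = torch.randn(2 * hidden_size, hidden_size)  # fused gate + up projection
gate_w, up_w = gate_up_proj.chunk(2, dim=0)

print(q_w.shape, k_w.shape, v_w.shape, gate_w.shape, up_w.shape)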
@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):


 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, name=None):
+        if name is None:
+            name = "model"
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+                f"{name}.embed_tokens"
+                if not prefix
+                else f"{prefix}.{name}.embed_tokens"
             ),
             weights=weights,
         )
         self.model = MistralModel(
-            prefix="model" if not prefix else f"{prefix}.model",
+            prefix=name if not prefix else f"{prefix}.{name}",
             config=config,
             weights=weights,
         )
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            # TODO dirty hack for idefics2.
+            prefix=(
+                "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
+            ),
             weights=weights,
         )
         self.max_past = config.sliding_window
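The new `name` argument only changes how weight prefixes are resolved, so the same Mistral module can be loaded standalone (tensors under `model.*`) or as the `text_model` inside Idefics2. A small sketch of the prefix arithmetic; the function below simply mirrors the expression in the hunk.

def embed_tokens_prefix(prefix, name="model"):
    # Same expression as in the hunk above.
    return f"{name}.embed_tokens" if not prefix else f"{prefix}.{name}.embed_tokens"

# Standalone Mistral checkpoint: tensors live under "model.*".
assert embed_tokens_prefix("") == "model.embed_tokens"

# Idefics2 wraps the language model, so with prefix="model" and name="text_model"
# the same code resolves to "model.text_model.embed_tokens".
assert embed_tokens_prefix("model", "text_model") == "model.text_model.embed_tokens"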
@@ -24,7 +24,10 @@ import torch.distributed
 import numpy as np

 from torch import nn
-from vllm.model_executor.layers.fused_moe import fused_moe
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
+if not IS_XPU_SYSTEM:
+    from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
829 server/text_generation_server/models/custom_modeling/idefics2.py Normal file
@@ -0,0 +1,829 @@
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Idefics2 model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.image_processing_utils import select_best_resolution
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
load_vision_model,
|
||||||
|
)
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
|
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||||
|
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||||
|
)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionEmbeddings(nn.Module):
|
||||||
|
"""
|
||||||
|
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
|
||||||
|
resolution.
|
||||||
|
|
||||||
|
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
|
||||||
|
which allows treating images in their native aspect ratio and without the need to resize them to the same
|
||||||
|
fixed size. In particular, we start from the original pre-trained SigLIP model
|
||||||
|
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.image_size = config.image_size
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
in_channels=config.num_channels,
|
||||||
|
out_channels=self.embed_dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size,
|
||||||
|
padding="valid",
|
||||||
|
)
|
||||||
|
self.patch_embedding.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.patch_embedding.bias = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_patches_per_side = self.image_size // self.patch_size
|
||||||
|
self.num_patches = self.num_patches_per_side**2
|
||||||
|
self.num_positions = self.num_patches
|
||||||
|
self.position_embedding = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.position_embedding", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||||
|
|
||||||
|
patch_embeds = self.patch_embedding(pixel_values)
|
||||||
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
max_nb_patches_h, max_nb_patches_w = (
|
||||||
|
max_im_h // self.patch_size,
|
||||||
|
max_im_w // self.patch_size,
|
||||||
|
)
|
||||||
|
boundaries = torch.arange(
|
||||||
|
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
|
||||||
|
)
|
||||||
|
position_ids = torch.full(
|
||||||
|
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
||||||
|
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||||
|
nb_patches_w = p_attn_mask[0].sum()
|
||||||
|
|
||||||
|
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
||||||
|
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
||||||
|
|
||||||
|
bucket_coords_h = torch.bucketize(
|
||||||
|
fractional_coords_h, boundaries, right=True
|
||||||
|
)
|
||||||
|
bucket_coords_w = torch.bucketize(
|
||||||
|
fractional_coords_w, boundaries, right=True
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_ids = (
|
||||||
|
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
|
||||||
|
).flatten()
|
||||||
|
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||||
|
|
||||||
|
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||||
|
embeddings = embeddings + self.position_embedding(position_ids)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionAttention(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_size = self.embed_dim // self.num_heads
|
||||||
|
if self.head_size * self.num_heads != self.embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||||
|
f" {self.num_heads})."
|
||||||
|
)
|
||||||
|
self.scale = self.head_size**-0.5
|
||||||
|
self.dropout = config.attention_dropout
|
||||||
|
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||||
|
|
||||||
|
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.out_proj = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
qkv = self.qkv(hidden_states)
|
||||||
|
query_states, key_states, value_states = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
k_v_seq_len = key_states.shape[-2]
|
||||||
|
attn_weights = (
|
||||||
|
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(
|
||||||
|
attn_weights, p=self.dropout, training=self.training
|
||||||
|
)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.activation_fn = ACT2FN[config.hidden_act]
|
||||||
|
self.fc1 = TensorParallelColumnLinear.load(
|
||||||
|
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.fc2 = TensorParallelRowLinear.load(
|
||||||
|
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2EncoderLayer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.self_attn = Idefics2VisionAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.layer_norm1 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
|
||||||
|
)
|
||||||
|
self.layer_norm2 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
|
||||||
|
)
|
||||||
|
self.mlp = Idefics2VisionMLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2Encoder(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Idefics2EncoderLayer(
|
||||||
|
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||||
|
)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
for encoder_layer in self.layers:
|
||||||
|
hidden_states = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionTransformer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embeddings = Idefics2VisionEmbeddings(
|
||||||
|
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.encoder = Idefics2Encoder(
|
||||||
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.post_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.post_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values,
|
||||||
|
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
):
|
||||||
|
batch_size = pixel_values.size(0)
|
||||||
|
if patch_attention_mask is None:
|
||||||
|
patch_size = self.config.patch_size
|
||||||
|
patch_attention_mask = torch.ones(
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
pixel_values.size(2) // patch_size,
|
||||||
|
pixel_values.size(3) // patch_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
patch_attention_mask = patch_attention_mask.to(
|
||||||
|
dtype=torch.bool, device=pixel_values.device
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.embeddings(
|
||||||
|
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
||||||
|
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||||
|
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
||||||
|
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
||||||
|
if not torch.any(~patch_attention_mask):
|
||||||
|
patch_attention_mask = None
|
||||||
|
else:
|
||||||
|
patch_attention_mask = _prepare_4d_attention_mask(
|
||||||
|
patch_attention_mask, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
inputs_embeds=hidden_states,
|
||||||
|
attention_mask=patch_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden_state = encoder_outputs
|
||||||
|
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||||
|
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2MLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
act = config.text_config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
start_shape = hidden_states.shape[:-1]
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
intermediate_size = gate_up_states.shape[-1] // 2
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
|
||||||
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
|
||||||
|
).view(*start_shape, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2RMSNorm(nn.Module):
|
||||||
|
def __init__(self, prefix, weights, eps):
|
||||||
|
"""
|
||||||
|
Idefics2RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverAttention(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer_idx = None
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.num_heads = config.perceiver_config.resampler_n_heads
|
||||||
|
self.head_size = config.perceiver_config.resampler_head_dim
|
||||||
|
self.num_key_value_heads = config.perceiver_config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.attention_dropout = config.perceiver_config.attention_dropout
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
self.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.q_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.q_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.kv = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = latents.size()
|
||||||
|
kv_seq_len = q_len + context.size()[1]
|
||||||
|
|
||||||
|
hidden_states = torch.concat([context, latents], dim=-2)
|
||||||
|
query_states = self.q_proj(latents)
|
||||||
|
kv = self.kv(hidden_states)
|
||||||
|
key_states, value_states = kv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_key_value_heads,
|
||||||
|
self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
query_states, key_states.transpose(2, 3)
|
||||||
|
) / math.sqrt(self.head_size)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverLayer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||||
|
self.depth = config.perceiver_config.resampler_depth
|
||||||
|
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||||
|
|
||||||
|
self.input_latents_norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.input_latents_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.input_context_norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.input_context_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.self_attn = Idefics2PerceiverAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||||
|
"""
|
||||||
|
residual = latents
|
||||||
|
|
||||||
|
latents = self.input_latents_norm(latents)
|
||||||
|
context = self.input_context_norm(context)
|
||||||
|
|
||||||
|
latents = self.self_attn(
|
||||||
|
latents=latents,
|
||||||
|
context=context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
latents = residual + latents
|
||||||
|
residual = latents
|
||||||
|
|
||||||
|
latents = self.post_attention_layernorm(latents)
|
||||||
|
latents = self.mlp(latents)
|
||||||
|
latents = residual + latents
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverResampler(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.hidden_act = config.perceiver_config.hidden_act
|
||||||
|
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||||
|
self.depth = config.perceiver_config.resampler_depth
|
||||||
|
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||||
|
|
||||||
|
# Create Latents for Perceiver
|
||||||
|
self.latents = weights.get_tensor(f"{prefix}.latents")
|
||||||
|
|
||||||
|
# Create Transformer Blocks
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Idefics2PerceiverLayer(
|
||||||
|
prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
|
||||||
|
)
|
||||||
|
for idx in range(self.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.text_config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# seq embed -> bsz seq embed
|
||||||
|
latents = self.latents.unsqueeze(0).expand(
|
||||||
|
(context.shape[0], *self.latents.size())
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_attention_mask = torch.ones(
|
||||||
|
(attention_mask.size(0), latents.size(1)),
|
||||||
|
dtype=attention_mask.dtype,
|
||||||
|
device=attention_mask.device,
|
||||||
|
)
|
||||||
|
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
|
||||||
|
attention_mask = _prepare_4d_attention_mask(
|
||||||
|
attention_mask, latents.dtype, tgt_len=self.n_latents
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed_context = latents
|
||||||
|
for perceiver_layer in self.layers:
|
||||||
|
compressed_context = perceiver_layer(
|
||||||
|
compressed_context,
|
||||||
|
context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
compressed_context = self.norm(compressed_context)
|
||||||
|
|
||||||
|
return compressed_context
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2Connector(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.modality_projection = Idefics2MLP(
|
||||||
|
prefix=f"{prefix}.modality_projection", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.perceiver_resampler = Idefics2PerceiverResampler(
|
||||||
|
prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image_hidden_states, attention_mask):
|
||||||
|
image_hidden_states = self.modality_projection(image_hidden_states)
|
||||||
|
image_hidden_states = self.perceiver_resampler(
|
||||||
|
context=image_hidden_states, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
return image_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2ForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
config.vision_config.quantize = config.quantize
|
||||||
|
config.vision_config.use_medusa = config.use_medusa
|
||||||
|
config.text_config.quantize = config.quantize
|
||||||
|
config.text_config.use_medusa = config.use_medusa
|
||||||
|
|
||||||
|
vision_config = config.vision_config
|
||||||
|
self.text_model = load_text_model(
|
||||||
|
prefix="model" if not prefix else f"{prefix}.model",
|
||||||
|
config=config.text_config,
|
||||||
|
weights=weights,
|
||||||
|
name="text_model",
|
||||||
|
)
|
||||||
|
self.dtype = weights.dtype
|
||||||
|
self.vision_model = Idefics2VisionTransformer(
|
||||||
|
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
|
||||||
|
config=vision_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.connector = Idefics2Connector(
|
||||||
|
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.config = config
|
||||||
|
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||||
|
self.image_token_id = config.image_token_id
|
||||||
|
self.pad_token_id = (
|
||||||
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
def _merge_input_ids_with_image_features(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
image_features: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||||
|
# mask = input_ids == self.config.image_token_index
|
||||||
|
mask = input_ids == self.config.image_token_id
|
||||||
|
# Let's pray we have enabled enough slots !
|
||||||
|
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
# Unused here
|
||||||
|
image_sizes: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||||
|
all_states = []
|
||||||
|
all_pixel_values = pixel_values
|
||||||
|
all_pixel_mask = pixel_attention_mask
|
||||||
|
for i in range(batch_size):
|
||||||
|
pixel_values = all_pixel_values.to(
|
||||||
|
dtype=self.dtype
|
||||||
|
) # fp16 compatibility
|
||||||
|
pixel_values = pixel_values[i : i + 1]
|
||||||
|
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||||
|
|
||||||
|
# Remove padding images - padding images are full 0.
|
||||||
|
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||||
|
real_images_inds = (pixel_values == 0.0).sum(
|
||||||
|
dim=(-1, -2, -3)
|
||||||
|
) != nb_values_per_image
|
||||||
|
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||||
|
|
||||||
|
# Handle the vision attention mask
|
||||||
|
if pixel_attention_mask is None:
|
||||||
|
pixel_attention_mask = torch.ones(
|
||||||
|
size=(
|
||||||
|
pixel_values.size(0),
|
||||||
|
pixel_values.size(2),
|
||||||
|
pixel_values.size(3),
|
||||||
|
),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=pixel_values.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Remove padding images from the mask/pP p
|
||||||
|
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||||
|
pixel_attention_mask = pixel_attention_mask.view(
|
||||||
|
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||||
|
)
|
||||||
|
pixel_attention_mask = pixel_attention_mask[
|
||||||
|
real_images_inds
|
||||||
|
].contiguous()
|
||||||
|
|
||||||
|
patch_size = self.config.vision_config.patch_size
|
||||||
|
patches_subgrid = pixel_attention_mask.unfold(
|
||||||
|
dimension=1, size=patch_size, step=patch_size
|
||||||
|
)
|
||||||
|
patches_subgrid = patches_subgrid.unfold(
|
||||||
|
dimension=2, size=patch_size, step=patch_size
|
||||||
|
)
|
||||||
|
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||||
|
|
||||||
|
# Get sequence from the vision encoder
|
||||||
|
image_hidden_states = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
patch_attention_mask=patch_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Modality projection & resampling
|
||||||
|
image_hidden_states = self.connector(
|
||||||
|
image_hidden_states,
|
||||||
|
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
|
||||||
|
)
|
||||||
|
all_states.append(image_hidden_states)
|
||||||
|
image_hidden_states = torch.stack(all_states, dim=0)
|
||||||
|
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||||
|
# that simply don't exist
|
||||||
|
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||||
|
input_ids, inputs_embeds, image_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.text_model.model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
max_s=max_s,
|
||||||
|
true_max_s=max_s,
|
||||||
|
prefill_cache_indices=None,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||||
|
return logits, speculative_logits
|
@@ -23,6 +23,10 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution

+from text_generation_server.models.custom_modeling.vlm import (
+    load_text_model,
+    load_vision_model,
+)
 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
         return hidden_states


-def load_vision_model(prefix, config, weights):
-    if config.model_type == "clip_vision_model":
-        from text_generation_server.models.custom_modeling.clip import (
-            CLIPVisionTransformer,
-        )
-
-        return CLIPVisionTransformer(
-            prefix=f"{prefix}.vision_model", config=config, weights=weights
-        )
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
-
-
-def load_text_model(prefix, config, weights):
-    if config.model_type == "llama":
-        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-            FlashLlamaForCausalLM,
-        )
-
-        return FlashLlamaForCausalLM(prefix, config, weights)
-    elif config.model_type == "mistral":
-        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-            FlashMistralForCausalLM,
-        )
-
-        return FlashMistralForCausalLM(prefix, config, weights)
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
-
-
 class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
         """In place merges in vision_embeddings with inputs_embeds."""
         mask = input_ids == self.config.image_token_index
         # Let's pray we have enabled enough slots !
-        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        try:
+            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        except Exception as e:
+            raise RuntimeError(
+                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
+            )
         return inputs_embeds

     def forward(
@@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
+        # Unused for this model
+        pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.language_model.embed_tokens(input_ids)
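The try/except added above wraps the in-place merge of vision embeddings into the text embeddings. A minimal, self-contained sketch of that masked assignment with toy sizes; only the indexing pattern matches the hunk.

import torch

hidden, image_token_index = 8, 99
input_ids = torch.tensor([1, 99, 99, 5])    # two image-token slots
inputs_embeds = torch.zeros(4, hidden)
image_features = torch.randn(1, 2, hidden)  # must flatten to exactly 2 rows

mask = input_ids == image_token_index
# view(-1, hidden) flattens however many patches were produced; if that count
# does not match mask.sum(), the assignment raises, which is what the new
# try/except turns into a readable error about --max-input-tokens.
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
print(inputs_embeds.shape)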
28 server/text_generation_server/models/custom_modeling/vlm.py Normal file
@@ -0,0 +1,28 @@
+def load_text_model(prefix, config, weights, name=None):
+    if config.model_type == "llama":
+        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+            FlashLlamaForCausalLM,
+        )
+
+        return FlashLlamaForCausalLM(prefix, config, weights)
+    elif config.model_type == "mistral":
+        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+            FlashMistralForCausalLM,
+        )
+
+        return FlashMistralForCausalLM(prefix, config, weights, name=name)
+    else:
+        raise RuntimeError(f"Unsupported model type {config.model_type}")
+
+
+def load_vision_model(prefix, config, weights):
+    if config.model_type == "clip_vision_model":
+        from text_generation_server.models.custom_modeling.clip import (
+            CLIPVisionTransformer,
+        )
+
+        return CLIPVisionTransformer(
+            prefix=f"{prefix}.vision_model", config=config, weights=weights
+        )
+    else:
+        raise RuntimeError(f"Unsupported model type {config.model_type}")
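The new `vlm.py` helpers are plain dispatchers keyed on `config.model_type`. Below is a self-contained toy version of the same pattern, using placeholder classes rather than the real TGI models, to show how the optional `name` is threaded through.

from dataclasses import dataclass


@dataclass
class ToyLlama:
    prefix: str


@dataclass
class ToyMistral:
    prefix: str
    name: str = "model"


def toy_load_text_model(prefix, model_type, name=None):
    # Same dispatch shape as load_text_model above, minus the lazy imports.
    if model_type == "llama":
        return ToyLlama(prefix)
    elif model_type == "mistral":
        return ToyMistral(prefix, name=name or "model")
    raise RuntimeError(f"Unsupported model type {model_type}")


print(toy_load_text_model("model", "mistral", name="text_model"))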
@@ -33,6 +33,11 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION

 tracer = trace.get_tracer(__name__)
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)


 @dataclass
@@ -752,7 +757,10 @@ class FlashCausalLM(Model):

     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_XPU_SYSTEM:
+            torch.xpu.empty_cache()
         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -772,7 +780,10 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

-        torch.cuda.synchronize(self.device)
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.synchronize(self.device)
+        elif IS_XPU_SYSTEM:
+            torch.xpu.synchronize(self.device)

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
@@ -780,12 +791,20 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
             total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+            total_gpu_memory = torch.cuda.get_device_properties(
+                self.device
+            ).total_memory
+
             free_memory = max(
                 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
             )
+        elif IS_XPU_SYSTEM:
+            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
+            free_memory = int(total_gpu_memory * 0.5)
+        else:
+            raise NotImplementedError("FlashModel is only available on GPU")

         num_blocks = (
             # Leave 5% for some wiggle room
@@ -816,6 +835,8 @@ class FlashCausalLM(Model):
                         self.cuda_graph_warmup(bs, max_s, max_bt)
                 except torch.cuda.OutOfMemoryError:
                     logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

         return int(num_blocks * BLOCK_SIZE)

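The memory hunk above derives the number of KV-cache blocks from free device memory. A rough arithmetic sketch with made-up numbers; the repository's version also accounts for the blocks already allocated during warmup, which is omitted here.

BLOCK_SIZE = 16
num_kv_heads, head_size, num_layers = 8, 128, 32
dtype_size = 2  # fp16

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # key + value

MEMORY_FRACTION = 0.9
total_gpu_memory = 24 * 1024**3   # pretend 24 GiB device
total_free_memory = 20 * 1024**3  # pretend 20 GiB currently free

# CUDA/ROCm path: respect MEMORY_FRACTION; the XPU path in the hunk simply
# takes half of total device memory instead.
free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)

# Leave 5% for some wiggle room, then convert bytes into whole cache blocks.
num_blocks = int(free_memory * 0.95 // total_cache_size)
print(num_blocks, num_blocks * BLOCK_SIZE)  # blocks and the tokens they can hold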
@@ -2,14 +2,13 @@ import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, GenerationConfig
 from transformers.models.llama import LlamaTokenizer
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
-    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -19,6 +18,8 @@ from text_generation_server.utils import (

 tracer = trace.get_tracer(__name__)

+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+

 class FlashLlama(FlashCausalLM):
     def __init__(
@@ -34,6 +35,9 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

@@ -53,8 +57,17 @@ class FlashLlama(FlashCausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass

-        config = LlamaConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
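The `generation_config` hunk stashes a set of EOS ids on the tokenizer when a model such as Llama 3 declares several end tokens. The sketch below shows the kind of stopping check such a set enables; `_eos_token_ids` is the ad-hoc attribute introduced by the hunk, while the token ids and the check itself are illustrative, not the repository's StoppingCriteria code.

eos_token_ids = {128001, 128009}  # hypothetical extra EOS ids


def is_stop_token(token_id, eos_token_id, extra_eos=None):
    # The single configured EOS still works; the set only widens the check.
    if token_id == eos_token_id:
        return True
    return extra_eos is not None and token_id in extra_eos


print(is_stop_token(128009, 128001, eos_token_ids))  # True
print(is_stop_token(42, 128001, eos_token_ids))      # False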
@@ -33,8 +33,9 @@ tracer = trace.get_tracer(__name__)
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None


 def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
@@ -316,6 +317,9 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM

 tracer = trace.get_tracer(__name__)

@@ -32,6 +33,9 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

@@ -4,7 +4,7 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers.models.qwen2 import Qwen2Tokenizer
+from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
 from text_generation_server.models.cache_manager import BLOCK_SIZE
@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
 )
-from transformers.models.qwen2 import Qwen2Config
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral):
         else:
             raise NotImplementedError("FlashQwen2 is only available on GPU")
 
-        tokenizer = Qwen2Tokenizer.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral):
             trust_remote_code=trust_remote_code,
         )
 
-        config = Qwen2Config.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -33,6 +34,9 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 
@@ -18,6 +18,8 @@ from text_generation_server.utils import (
     Weights,
 )
 
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
 tracer = trace.get_tracer(__name__)
 
 
@@ -35,6 +37,9 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
@@ -1,7 +1,7 @@
 import torch
 import os
 
-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
@@ -11,4 +11,7 @@ if cuda_graphs is not None:
         raise RuntimeError(
             f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
         )
+else:
+    cuda_graphs = None
+
 CUDA_GRAPHS = cuda_graphs
51 server/text_generation_server/models/idefics2.py Normal file
@@ -0,0 +1,51 @@
+import torch
+
+from typing import Optional, Tuple
+
+from transformers import (
+    AutoProcessor,
+)
+from text_generation_server.models.custom_modeling.idefics2 import (
+    Idefics2ForConditionalGeneration,
+)
+
+from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+
+
+class Idefics2(VlmCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.processor = AutoProcessor.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            # XXX: Extremely important to cap resolution in order to limit
+            # VRAM usage.
+            size={"longest_edge": 448, "shortest_edge": 378},
+        )
+        super().__init__(
+            model_cls=Idefics2ForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.text_model.model.layers),
+            model.text_model.model.num_key_value_heads,
+            model.text_model.model.head_size,
+        )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.text_model, "max_past", None)
@@ -1,6 +1,6 @@
 import torch
 
-from typing import Optional
+from typing import Optional, Tuple
 
 from transformers import (
     AutoProcessor,
@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.language_model.model.layers),
+            model.language_model.model.num_key_value_heads,
+            model.language_model.model.head_size,
+        )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.language_model, "max_past", None)
@@ -474,6 +474,8 @@ class Mamba(Model):
                     self.cuda_graph_warmup(bs)
                 except Exception:
                     logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
         return None
 
@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model
         self.tokenizer = tokenizer
+
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size
 
 
+def image_text_replacement(image_input, config, image_id) -> str:
+    if config.model_type == "idefics2":
+        # TODO technically depends on image splitting which is not implemented.
+        num_features = 320
+        return (
+            "<fake_token_around_image>"
+            + "<image>" * num_features
+            + "<fake_token_around_image>"
+        )
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][image_id]
+        num_features = get_number_of_features(height, width, config)
+        from loguru import logger
+
+        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
+        return "<image>" * num_features
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def get_unpadded_features(
+    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
+
+
 def get_number_of_features(height: int, width: int, config) -> int:
     # From config
     # Hardcoded for CLIP for now
@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
         image_grid_pinpoints,
         image_size,
     )
-    height_of_patch = math.ceil(height / width * npatches)
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
     # The base patch covers the entire image
     base_features = npatches**2
     return unpadded_features + newline_features + base_features
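Note: get_unpadded_features above mirrors the upstream LLaVA-Next token-count arithmetic. Below is a minimal, self-contained sketch of how it combines with the base patch count; it assumes npatches=24 (a 336-pixel CLIP image with 14-pixel patches) and a 2x2 patch grid, which are illustrative values rather than something computed in this diff.

# Illustrative sketch only: reproduces the arithmetic of the hunk above with
# assumed values npatches=24 and num_patch_height = num_patch_width = 2.
def example_feature_count(height: int = 640, width: int = 640) -> int:
    npatches, num_patch_height, num_patch_width = 24, 2, 2
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
    aspect_ratio = width / height
    current_aspect_ratio = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        current_height = (height * current_width) // width
    else:
        current_width = (width * current_height) // height
    unpadded_features = current_height * current_width
    newline_features = current_height
    base_features = npatches**2
    return unpadded_features + newline_features + base_features

print(example_feature_count())  # 2928 for a square 640x640 image under these assumptions

Under these assumptions the result matches the 2928 value that appears in the commented-out assertions removed from this file.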
@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image
 
 
-# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
-# assert get_number_of_features(640, 640) == 2928
-
-
 class VlmCausalLMBatch(FlashMistralBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
 
     @classmethod
@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch
 
@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch
 
@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
         for r in requests:
             chunks = split(r.inputs)
             full_text = ""
+            image_id = 0
             for chunk in chunks:
                 if chunk["type"] == "text":
                     full_text += chunk["content"]
@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
-                    height, width = image_input["image_sizes"][0]
-                    num_features = get_number_of_features(height, width, config)
-                    full_text += "<image>" * num_features
+                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
         if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                 "pixel_values": torch.cat(
                     [img["pixel_values"] for img in image_inputs], dim=0
                 ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
             }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs
@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+            if "pixel_attention_mask" in image_inputs:
+                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+                    device=device
+                )
+            else:
+                batch.pixel_attention_mask = None
+            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.image_sizes = None
         else:
             batch.pixel_values = None
+            batch.pixel_attention_mask = None
             batch.image_sizes = None
         return batch
 
@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch
 
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
-
     def forward(
         self, batch: VlmCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
             max_s = min(self.max_past(), max_s)
 
         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
 
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
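Note: the hunk above replaces the fixed batch-padding rules with a lookup of the smallest captured CUDA graph that can hold the current batch. A minimal sketch of that selection logic follows; the cuda_graphs contents are hypothetical placeholders, not values from the diff.

# Illustrative sketch: pick the smallest captured CUDA-graph batch size >= bs.
cuda_graphs = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}  # hypothetical captures

def lookup_cuda_graph(bs: int):
    # Smallest captured batch size that fits, otherwise fall back to eager execution.
    candidates = sorted(k for k in cuda_graphs.keys() if k >= bs)
    return cuda_graphs[candidates[0]] if candidates else None

assert lookup_cuda_graph(3) == "graph_bs4"
assert lookup_cuda_graph(16) is None  # no graph large enough -> eager path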
@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
                 image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
             if batch.image_sizes is not None:
                 batch.image_sizes = None
             return logits, speculative_logits
@@ -5,6 +5,7 @@ import os
 import sys
 import torch
 import time
+import signal
 
 from grpc import aio
 from loguru import logger
@@ -20,6 +21,21 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
+
+
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(
         self,
@@ -201,11 +217,8 @@ def serve(
 
         logger.info("Server started at {}".format(local_url))
 
-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)
 
     asyncio.run(
         serve_inner(
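Note: with the SignalHandler introduced above, serve_inner polls KEEP_PROCESSING instead of waiting for a KeyboardInterrupt, so SIGTERM and SIGINT both lead to a clean exit. A minimal sketch of the same pattern outside gRPC; the asyncio loop below is illustrative, not the actual server code.

import asyncio
import signal


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


async def main():
    handler = SignalHandler()
    while handler.KEEP_PROCESSING:
        # In the real server this is where the gRPC server keeps serving requests.
        await asyncio.sleep(0.5)


asyncio.run(main())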
@@ -68,7 +68,15 @@ def initialize_torch_distributed():
         if world_size > n_hpus:
             raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).")
     else:
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
             backend = "gloo"
+            options = None
 
     if WORLD_SIZE == 1:
         return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
@@ -2,24 +2,36 @@ import os
 import torch
 
 from loguru import logger
+import math
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
+HAS_FLASH_ATTN = True
-if not torch.cuda.is_available():
-    raise ImportError("CUDA is not available")
-
-major, minor = torch.cuda.get_device_capability()
-is_sm75 = major == 7 and minor == 5
-is_sm8x = major == 8 and minor >= 0
-is_sm90 = major == 9 and minor == 0
-
-HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
-try:
+
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
+if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+
+    HAS_FLASH_ATTN = False
+    HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
+    try:
         try:
             import flash_attn_2_cuda
         except ImportError:
@@ -40,7 +52,7 @@ try:
         )
         HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
         HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
     except ImportError as e:
         try:
             import flash_attn_cuda
         except ImportError:
@@ -80,6 +92,28 @@ def attention(
     if window_size_left <= 0 and window_size_left != -1:
         raise ValueError("`window_size_left` must be > 0 or -1")
 
+    if IS_XPU_SYSTEM:
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
+        return ipex.llm.functional.varlen_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -1,4 +1,15 @@
 import torch
 
+
+def is_xpu_available():
+    try:
+        import intel_extension_for_pytorch
+    except ImportError:
+        return False
+
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_XPU_SYSTEM = is_xpu_available()
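Note: is_xpu_available() only reports True when intel_extension_for_pytorch is importable and torch exposes a live XPU device, and IS_XPU_SYSTEM is the flag that gates every XPU branch added in this PR. A small sketch of the device-selection pattern the flash model constructors above now share; pick_device is a hypothetical helper used only for illustration, not part of the diff.

import torch

from text_generation_server.utils.import_utils import IS_XPU_SYSTEM


def pick_device(rank: int = 0) -> torch.device:
    # Mirrors the cuda / xpu / else branches added to the flash model constructors.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    if IS_XPU_SYSTEM:
        return torch.device(f"xpu:{rank}")
    raise NotImplementedError("Flash models are only available on GPU")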
@@ -8,6 +8,8 @@ from typing import List, Tuple, Optional
 from loguru import logger
 from functools import lru_cache
 
+from text_generation_server.utils.speculate import get_speculate
+
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
@@ -18,7 +20,14 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
 
 HAS_AWQ = True
 try:
@@ -437,7 +446,7 @@ class MedusaModel(torch.nn.Module):
         self.heads = torch.nn.ModuleList(
             [
                 MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
-                for i in range(medusa_config["medusa_num_heads"])
+                for i in range(get_speculate())
             ]
         )
 
@@ -534,7 +543,7 @@ class MedusaHeadV2(nn.Module):
             )
             routing[k] = filename
 
-        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+        self.n_medusa_heads = get_speculate()
 
         assert medusa_config["medusa_num_layers"] == 1
         self.linear = TensorParallelColumnLinear.load_multi(
@@ -696,6 +705,19 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_gate_up(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_gate_up only implemented without bias")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load_qkv(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
@@ -799,7 +821,15 @@ try:
 
     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+            if IS_XPU_SYSTEM:
+                res_out = hidden_states
+                out = ipex.llm.functional.add_layer_norm(
+                    residual, hidden_states, self.weight, self.bias, self.eps, True
+                )
+                if residual is not None:
+                    res_out = residual
+                return out, res_out
+            elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
@@ -845,7 +875,20 @@ try:
             return cls(weight, eps)
 
         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            if IS_XPU_SYSTEM:
+                residual_out = hidden_states
+                out = ipex.llm.functional.add_rms_norm(
+                    residual,
+                    hidden_states,
+                    self.weight,
+                    None,
+                    self.variance_epsilon,
+                    True,
+                )
+                if residual is not None:
+                    residual_out = residual
+                return out, residual_out
+            elif hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
@@ -971,6 +1014,10 @@ try:
 
             # Inplace operation, updating query and key.
             pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
+        elif IS_XPU_SYSTEM:
+            ipex.llm.functional.rotary_embedding(
+                query, key, sin, cos, query.size(-1), True
+            )
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
@@ -1090,6 +1137,7 @@ try:
 
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
+
         # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
         return cos.unsqueeze(1), sin.unsqueeze(1)
 
@@ -132,13 +132,16 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
         score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)
 
         return scores.scatter_add_(1, input_ids, score)
 
 
 class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    Frequency penalty as defined by OpenAI
+    Frequency penalty as defined by OpenAI in
+    https://platform.openai.com/docs/guides/text-generation/parameter-details
 
     Args:
         frequency_penalty (`List[float]`):
@@ -152,13 +155,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
         ).unsqueeze(1)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        score = torch.gather(scores, 1, input_ids)
-        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
-        )
-
-        return scores.scatter_add_(1, input_ids, score)
+        batch_size, input_size = input_ids.size()
+        vocab_size = scores.size(1)
+
+        # Calculate the frequency for each token so far
+        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+        token_freq.scatter_add_(
+            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+        )
+        token_freq /= input_size
+
+        # Apply the frequency penalty to logits
+        scores -= token_freq * self.penalty_tensor
+        return scores
 
     def filter(self, indices):
         self.penalty = [self.penalty[i] for i in indices]
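Note: the heterogeneous frequency penalty above now subtracts penalty * token_frequency from the raw logits rather than rescaling gathered scores. A worked toy example of that computation; all values are illustrative.

import torch

input_ids = torch.tensor([[2, 2, 3]])      # one sequence, tokens generated so far
scores = torch.zeros(1, 4)                 # vocab of 4, flat logits for clarity
penalty = torch.tensor([[0.5]])            # per-sequence frequency_penalty

batch_size, input_size = input_ids.size()
vocab_size = scores.size(1)

token_freq = torch.zeros(batch_size, vocab_size)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
token_freq /= input_size                   # token 2 -> 2/3, token 3 -> 1/3

scores -= token_freq * penalty
print(scores)                              # tensor([[ 0.0000,  0.0000, -0.3333, -0.1667]])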
@@ -1,9 +1,15 @@
 import torch
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 _PARTITION_SIZE = 512
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 
 def reshape_and_cache(
     key: torch.Tensor,
@@ -22,6 +28,10 @@ def reshape_and_cache(
         from vllm import cache_ops
 
         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+    elif IS_XPU_SYSTEM:
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
     else:
         raise ValueError("vllm is not supported on your system")
 
@@ -58,6 +68,22 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    if IS_XPU_SYSTEM:
+        query = query.contiguous()
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -1,7 +1,7 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 
 import re
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Set, Union
 
 import math
 import torch
@@ -144,12 +144,22 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_id: int,
+        eos_token_ids: Optional[Union[Set[int], int]],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
-        self.eos_token_id = eos_token_id
+        if eos_token_ids is None:
+            eos_token_ids = set()
+        elif isinstance(eos_token_ids, int):
+            eos_token_ids = set([eos_token_ids])
+        elif isinstance(eos_token_ids, set):
+            eos_token_ids = eos_token_ids
+        else:
+            raise RuntimeError(
+                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+            )
+        self.eos_token_ids = eos_token_ids
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
@@ -161,7 +171,10 @@ class StoppingCriteria:
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH
 
-        if not self.ignore_eos_token and last_token == self.eos_token_id:
+        if isinstance(last_token, torch.Tensor):
+            last_token = last_token.item()
+
+        if not self.ignore_eos_token and last_token in self.eos_token_ids:
             return True, FinishReason.FINISH_REASON_EOS_TOKEN
 
         if self.stop_sequence_criterias:
@@ -182,9 +195,13 @@ class StoppingCriteria:
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
-        stop_sequence_criterias = [StopSequenceCriteria(sequence) for sequence in pb.stop_sequences]
+        stop_sequence_criterias = [
+            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+        ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if any([x != 1.0 for x in temperature]):
|
if any(x != 1.0 for x in temperature):
|
||||||
do_sample = [
|
do_sample = [
|
||||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||||
]
|
]
|
||||||
@ -282,15 +299,15 @@ class HeterogeneousNextTokenChooser:
|
|||||||
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
|
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
|
||||||
)
|
)
|
||||||
|
|
||||||
if any([x != 0 for x in top_k]):
|
if any(x != 0 for x in top_k):
|
||||||
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
||||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
|
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
|
||||||
|
|
||||||
if any([x < 1.0 for x in top_p]):
|
if any(x < 1.0 for x in top_p):
|
||||||
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
||||||
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
|
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
|
||||||
|
|
||||||
if any([x < 1.0 for x in typical_p]):
|
if any(x < 1.0 for x in typical_p):
|
||||||
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
|
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
|
||||||
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
|
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
|
||||||
|
|
||||||
|
@@ -141,6 +141,12 @@ class Weights:
         return weight
 
     def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 3)
+
+    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 2)
+
+    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor
@@ -181,8 +187,8 @@ class Weights:
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
-            single_size = total_size // 3
+            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+            single_size = total_size // blocks
             world_size = self.process_group.size()
             rank = self.process_group.rank()
 
@@ -192,10 +198,11 @@ class Weights:
             block_size = single_size // world_size
             start = rank * block_size
             stop = (rank + 1) * block_size
-            q = slice_[start:stop]
-            k = slice_[start + single_size : stop + single_size]
-            v = slice_[start + 2 * single_size : stop + 2 * single_size]
-            weight = torch.cat([q, k, v], dim=0)
+            tensors = []
+            for i in range(blocks):
+                tensor = slice_[start + i * single_size : stop + i * single_size]
+                tensors.append(tensor)
+            weight = torch.cat(tensors, dim=0)
         weight = weight.to(device=self.device)
         weight = weight.to(dtype=self.dtype)
         return weight
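Note: get_weights_col_packed now takes a blocks argument so the same sharding code serves the packed QKV tensor (blocks=3) and the packed gate_up tensor (blocks=2). A standalone sketch of the per-rank slicing on a plain tensor; shapes, values, and the helper name are illustrative, and this operates on an in-memory tensor rather than a safetensors slice.

import torch


def split_packed(weight: torch.Tensor, blocks: int, rank: int, world_size: int) -> torch.Tensor:
    # Mirrors the loop introduced in the hunk above.
    total_size = weight.shape[0]
    assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
    single_size = total_size // blocks
    assert single_size % world_size == 0, "each block must shard evenly across ranks"
    block_size = single_size // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    tensors = [weight[start + i * single_size : stop + i * single_size] for i in range(blocks)]
    return torch.cat(tensors, dim=0)


w = torch.arange(12.0).reshape(12, 1)                   # e.g. a packed gate_up weight: 2 blocks of 6 rows
print(split_packed(w, blocks=2, rank=0, world_size=2))  # rank 0 gets rows 0-2 and 6-8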