diff --git a/Dockerfile.neuron b/Dockerfile.neuron new file mode 100644 index 00000000..7a65e73c --- /dev/null +++ b/Dockerfile.neuron @@ -0,0 +1,173 @@ +# Fetch and extract the TGI sources +FROM alpine AS tgi +RUN mkdir -p /tgi + +# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments +FROM alpine AS optimum-neuron +RUN mkdir -p /optimum-neuron +ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.0.28.tar.gz /optimum-neuron/sources.tar.gz +RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1 + +# Build cargo components (adapted from TGI original Dockerfile) +# Note: we cannot use the cargo-chef base image as it uses python 3.11 +FROM ubuntu:22.04 AS chef + +RUN apt-get update -y \ + && apt-get install -y --no-install-recommends \ + curl ca-certificates build-essential \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.80.1 --profile minimal -y +ENV PATH="/root/.cargo/bin:${PATH}" +RUN cargo install cargo-chef --locked + +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef AS planner +COPY backends/neuron/Cargo.toml Cargo.toml +COPY Cargo.lock Cargo.lock +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY router router +COPY backends backends +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +RUN apt-get update -y \ + && apt-get install -y --no-install-recommends \ + unzip python3-dev libssl-dev pkg-config \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY backends/neuron/Cargo.toml Cargo.toml +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.lock Cargo.lock +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY router router +COPY backends backends +COPY launcher launcher +# Remove this line once TGI has fixed the conflict +RUN cargo update ureq --precise 2.9.7 +RUN cargo build --release + +# Python base image +FROM ubuntu:22.04 AS base + +RUN apt-get update -y \ + && apt-get install -y --no-install-recommends \ + python3-pip \ + python3-setuptools \ + python-is-python3 \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean +RUN pip3 --no-cache-dir install --upgrade pip + +# Python server build image +FROM base AS pyserver + +RUN apt-get update -y \ + && apt-get install -y --no-install-recommends \ + make \ + python3-venv \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +RUN install -d /pyserver +WORKDIR /pyserver +COPY backends/neuron/server server +COPY proto proto +RUN pip3 install -r server/build-requirements.txt +RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package + +# Neuron base image (used for deployment) +FROM base AS neuron + +# Install system prerequisites +RUN apt-get update -y \ + && apt-get install -y --no-install-recommends \ + gnupg2 \ + wget \ + python3-dev \ + libexpat1 \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list +RUN wget -qO - 
https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add - + +# Install neuronx packages +RUN apt-get update -y \ + && apt-get install -y --no-install-recommends \ + aws-neuronx-dkms=2.18.20.0 \ + aws-neuronx-collectives=2.22.33.0-d2128d1aa \ + aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 \ + aws-neuronx-tools=2.19.0.0 \ + libxml2 \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}" + +RUN pip3 install \ + neuronx-cc==2.15.143.0 \ + torch-neuronx==2.1.2.2.3.2 \ + transformers-neuronx==0.12.313 \ + neuronx-distributed==0.9.0 \ + libneuronxla==2.0.5347.0 \ + --extra-index-url=https://pip.repos.neuron.amazonaws.com + +# Install HuggingFace packages +RUN pip3 install \ + hf_transfer huggingface_hub + +# Install optimum-neuron +COPY --from=optimum-neuron /optimum-neuron optimum-neuron +RUN pip3 install ./optimum-neuron + +# TGI base env +ENV HUGGINGFACE_HUB_CACHE=/tmp \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +# Disable color logs as they are not supported by CloudWatch +ENV LOGURU_COLORIZE=NO +ENV LOG_COLORIZE=0 + +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +# Install python server +COPY --from=pyserver /pyserver/build/dist dist +RUN pip install dist/text_generation_server*.tar.gz + +# AWS Sagemaker compatible image +FROM neuron AS sagemaker + +COPY backends/neuron/sagemaker-entrypoint.sh entrypoint.sh +RUN chmod +x entrypoint.sh + +ENTRYPOINT ["./entrypoint.sh"] + +# Final image +FROM neuron + +COPY backends/neuron/tgi_env.py /tgi_env.py +COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh +RUN chmod +x /tgi-entrypoint.sh + +ENTRYPOINT ["/tgi-entrypoint.sh"] diff --git a/backends/neuron/Cargo.toml b/backends/neuron/Cargo.toml new file mode 100644 index 00000000..3b237eda --- /dev/null +++ b/backends/neuron/Cargo.toml @@ -0,0 +1,47 @@ +[workspace] +members = [ + "backends/v2", + "backends/grpc-metadata", + "launcher", + "router" +] +default-members = [ + "backends/v2", + "backends/grpc-metadata", + "launcher", + "router" +] +resolver = "2" + +[workspace.package] +version = "3.0.0" +edition = "2021" +authors = ["Olivier Dehaene"] +homepage = "https://github.com/huggingface/text-generation-inference" + +[workspace.dependencies] +base64 = "0.22.0" +tokenizers = { version = "0.20.0", features = ["http"] } +hf-hub = { version = "0.3.1", features = ["tokio"] } +metrics = { version = "0.23.0" } +metrics-exporter-prometheus = { version = "0.15.1", features = [] } +minijinja = { version = "2.2.0", features = ["json"] } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +pyo3 = { version = "0.22.2", features = ["auto-initialize"] } + +[profile.release] +incremental = true + +[profile.release-binary] +inherits = "release" +debug = 1 +incremental = true +panic = "abort" + +[profile.release-opt] +inherits = "release" +debug = 0 +incremental = false +lto = "fat" +opt-level = 3 +codegen-units = 1 diff --git a/backends/neuron/Makefile b/backends/neuron/Makefile new file mode 100644 index 00000000..f7a33393 --- /dev/null +++ b/backends/neuron/Makefile @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
+mkfile_dir := $(dir $(mkfile_path))
+root_dir := "${mkfile_dir}/../.."
+
+.PHONY: image
+
+VERSION := $(shell gawk 'match($$0, /^version = "(.*)"/, a) {print a[1]}' ${root_dir}/Cargo.toml)
+
+image:
+	docker build --rm -f ${root_dir}/Dockerfile.neuron \
+	  --build-arg VERSION=$(VERSION) \
+	  --ulimit nofile=100000:100000 \
+	  -t text-generation-inference:$(VERSION)-neuron ${root_dir}
+	docker tag text-generation-inference:$(VERSION)-neuron text-generation-inference:latest-neuron
diff --git a/backends/neuron/README.md b/backends/neuron/README.md
new file mode 100644
index 00000000..55722c3b
--- /dev/null
+++ b/backends/neuron/README.md
@@ -0,0 +1,25 @@
+# Text-generation-inference - Neuron backend for AWS Trainium and Inferentia2
+
+## Description
+
+This is the TGI backend for the AWS Trainium and Inferentia family of chips.
+
+This backend is composed of:
+- the AWS Neuron SDK,
+- the legacy v2 TGI launcher and router,
+- a Neuron-specific inference server for text generation.
+
+## Usage
+
+Please refer to the official [documentation](https://huggingface.co/docs/text-generation-inference/backends/neuron).
+
+## Build your own image
+
+The simplest way to build TGI with the Neuron backend is to use the provided `Makefile`:
+
+```shell
+$ make -C backends/neuron image
+```
+
+Alternatively, you can build the image directly from the top directory using a command similar to the one defined
+in the `Makefile` under the `image` target.
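+
+For reference, the sketch below spells out such a command. It assumes you run it from the repository root and that the `VERSION` value matches the workspace version declared in `Cargo.toml` (3.0.0 at the time of writing):
+
+```shell
+# Build the Neuron image from the repository root; adjust VERSION to the current workspace version.
+docker build --rm -f Dockerfile.neuron \
+  --build-arg VERSION=3.0.0 \
+  --ulimit nofile=100000:100000 \
+  -t text-generation-inference:3.0.0-neuron .
+```
+
+The `Makefile` additionally tags the resulting image as `text-generation-inference:latest-neuron`.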
diff --git a/backends/neuron/sagemaker-entrypoint.sh b/backends/neuron/sagemaker-entrypoint.sh new file mode 100644 index 00000000..a8a6a730 --- /dev/null +++ b/backends/neuron/sagemaker-entrypoint.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +if [[ -z "${HF_MODEL_ID}" ]]; then + echo "HF_MODEL_ID must be set" + exit 1 +fi +export MODEL_ID="${HF_MODEL_ID}" + +if [[ -n "${HF_MODEL_REVISION}" ]]; then + export REVISION="${HF_MODEL_REVISION}" +fi + +if [[ -n "${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then + export TRUST_REMOTE_CODE="${HF_MODEL_TRUST_REMOTE_CODE}" +fi + +if [[ -z "${MAX_BATCH_SIZE}" ]]; then + echo "MAX_BATCH_SIZE must be set to the model static batch size" + exit 1 +fi + +text-generation-launcher --port 8080 diff --git a/backends/neuron/server/.gitignore b/backends/neuron/server/.gitignore new file mode 100644 index 00000000..378eac25 --- /dev/null +++ b/backends/neuron/server/.gitignore @@ -0,0 +1 @@ +build diff --git a/backends/neuron/server/Makefile b/backends/neuron/server/Makefile new file mode 100644 index 00000000..24d5e0d1 --- /dev/null +++ b/backends/neuron/server/Makefile @@ -0,0 +1,57 @@ +# Initialize base variables +pkg_name := text_generation_server +BUILDDIR ?= $(CURDIR)/build +VERSION ?= 0.0.1 +mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST))) +mkfile_dir := $(dir $(mkfile_path)) +pkg_dir := $(BUILDDIR)/$(pkg_name) +pkg_dist := ${BUILDDIR}/dist/${pkg_name}-${VERSION}.tar.gz + +clean: + rm -rf $(BUILDDIR)/* + +# List static sources to be deployed in the package +src_dir := $(mkfile_dir)/$(pkg_name) +sources := $(wildcard $(src_dir)/*.py) +deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources)) + +# Static files are just copied + +define COPY + cp -f $< $@ +endef + +$(BUILDDIR)/pyproject.toml: $(mkfile_dir)/pyproject.toml + mkdir -p $(BUILDDIR) + $(COPY) + sed -i -e 's/version = "VERSION"/version = \"${VERSION}\"/' $@ + +$(pkg_dir)/%.py: $(src_dir)/%.py + mkdir -p $(pkg_dir) + $(COPY) + +# Generated files are produced by grpcio tools + +# If not provided, get local proto files +ifndef PROTODIR +PROTODIR := $(mkfile_dir)/../../../proto +endif + +# Three python files are generated for each protobuf +protobufs := $(PROTODIR)/generate.proto +pkg_pb_dir := $(pkg_dir)/pb +generated_sources_base := $(foreach proto, $(protobufs), $(proto:.proto=_pb2.py)) +generated_sources := $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base)) +generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=.pyi)) +generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=_grpc.py)) + +$(pkg_pb_dir)/%_pb2.py $(pkg_pb_dir)/%_pb2.pyi $(pkg_pb_dir)/%_pb2_grpc.py: $(PROTODIR)/%.proto + mkdir -p $(pkg_pb_dir) + python -m grpc_tools.protoc -I$(PROTODIR) --python_out=$(pkg_pb_dir) \ + --grpc_python_out=$(pkg_pb_dir) --mypy_out=$(pkg_pb_dir) $^ + sed -i -e 's/^\(import.*pb2\)/from . 
\1/g' $(pkg_pb_dir)/$*_pb2_grpc.py + +${pkg_dist}: $(BUILDDIR)/pyproject.toml $(deployed_sources) $(generated_sources) + python -m build $(BUILDDIR) + +package: ${pkg_dist} diff --git a/backends/neuron/server/build-requirements.txt b/backends/neuron/server/build-requirements.txt new file mode 100644 index 00000000..2083bd73 --- /dev/null +++ b/backends/neuron/server/build-requirements.txt @@ -0,0 +1,3 @@ +build +grpcio-tools==1.53.0 +mypy-protobuf diff --git a/backends/neuron/server/pyproject.toml b/backends/neuron/server/pyproject.toml new file mode 100644 index 00000000..f37dc91f --- /dev/null +++ b/backends/neuron/server/pyproject.toml @@ -0,0 +1,25 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "text-generation-server" +version = "VERSION" +authors = [{name="David Corvoysier", email="david@huggingface.co" }] +description = "TGI compatible inference server for AWS Neuronx platforms" +dependencies = [ + 'protobuf > 3.20.1, < 4', + 'grpcio == 1.57.0', + 'grpcio-status == 1.48.2', + 'grpcio-reflection == 1.48.2', + 'grpc-interceptor == 0.15.2', + 'typer == 0.6.1', + 'safetensors', + 'loguru == 0.6.0' +] + +[tool.setuptools] +packages = ["text_generation_server", "text_generation_server.pb"] + +[project.scripts] +text-generation-server = 'text_generation_server.cli:app' diff --git a/backends/neuron/server/text_generation_server/cli.py b/backends/neuron/server/text_generation_server/cli.py new file mode 100644 index 00000000..409143a9 --- /dev/null +++ b/backends/neuron/server/text_generation_server/cli.py @@ -0,0 +1,111 @@ +import sys +from typing import Optional + +import typer +from loguru import logger + + +app = typer.Typer() + + +@app.command() +def serve( + model_id: str, + revision: Optional[str] = None, + sharded: bool = False, + trust_remote_code: bool = None, + uds_path: str = "/tmp/text-generation-server", + logger_level: str = "INFO", + json_output: bool = False, + otlp_endpoint: Optional[str] = None, + otlp_service_name: str = "text-generation-inference.server", + max_input_tokens: Optional[int] = None, +): + """This is the main entry-point for the server CLI. + + Args: + model_id (`str`): + The *model_id* of a model on the HuggingFace hub or the path to a local model. + revision (`Optional[str]`, defaults to `None`): + The revision of the model on the HuggingFace hub. + sharded (`bool`): + Whether the model must be sharded or not. Kept for compatibility with the + text-generation-launcher, but must be set to False. + trust-remote-code (`bool`): + Kept for compatibility with text-generation-launcher. Ignored. + uds_path (`Union[Path, str]`): + The local path on which the server will expose its google RPC services. + logger_level (`str`): + The server logger level. Defaults to *INFO*. + json_output (`bool`): + Use JSON format for log serialization. + otlp_endpoint (`Optional[str]`, defaults to `None`): + The Open Telemetry endpoint to use. + otlp_service_name (`Optional[str]`, defaults to `None`): + The name to use when pushing data to the Open Telemetry endpoint. + max_input_tokens (`Optional[int]`, defaults to `None`): + The maximum number of input tokens each request should contain. 
+ """ + if sharded: + raise ValueError("Sharding is not supported.") + # Remove default handler + logger.remove() + logger.add( + sys.stdout, + format="{message}", + filter="text_generation_server", + level=logger_level, + serialize=json_output, + backtrace=True, + diagnose=False, + ) + + if trust_remote_code is not None: + logger.warning("'trust_remote_code' argument is not supported and will be ignored.") + + # Import here after the logger is added to log potential import exceptions + from .server import serve + + serve(model_id, revision, uds_path) + + +@app.command() +def download_weights( + model_id: str, + revision: Optional[str] = None, + logger_level: str = "INFO", + json_output: bool = False, + auto_convert: Optional[bool] = None, + extension: Optional[str] = None, + trust_remote_code: Optional[bool] = None, + merge_lora: Optional[bool] = None, +): + """Download the model weights. + + This command will be called by text-generation-launcher before serving the model. + """ + # Remove default handler + logger.remove() + logger.add( + sys.stdout, + format="{message}", + filter="text_generation_server", + level=logger_level, + serialize=json_output, + backtrace=True, + diagnose=False, + ) + + if extension is not None: + logger.warning("'extension' argument is not supported and will be ignored.") + if trust_remote_code is not None: + logger.warning("'trust_remote_code' argument is not supported and will be ignored.") + if auto_convert is not None: + logger.warning("'auto_convert' argument is not supported and will be ignored.") + if merge_lora is not None: + logger.warning("'merge_lora' argument is not supported and will be ignored.") + + # Import here after the logger is added to log potential import exceptions + from .model import fetch_model + + fetch_model(model_id, revision) diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py new file mode 100644 index 00000000..3ddee690 --- /dev/null +++ b/backends/neuron/server/text_generation_server/generator.py @@ -0,0 +1,636 @@ +import copy +import logging +import time +from abc import ABC +from enum import Enum +from typing import List, Optional, Tuple + +import torch +from loguru import logger +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase +from transformers.generation import GenerationConfig + +from optimum.neuron import NeuronModelForCausalLM +from optimum.neuron.generation import TokenSelector + +from .model import get_export_kwargs_from_env +from .pb.generate_pb2 import ( + Batch, + CachedBatch, + FinishReason, + GeneratedText, + Generation, + InfoResponse, + Request, + Tokens, +) + + +# Disable optimum-neuron warnings as it seems to block the server after a while +optimum_logger = logging.getLogger("optimum.neuron") +optimum_logger.setLevel("CRITICAL") + + +class Generator(ABC): + """An abstract class to represent the workhorse behind TextGenerationService. + + Ideally, it should not rely on protobuf constructs, but in a first step it does. + Implementations would typically need a model and a tokenizer to implement the Generator methods. + """ + + @property + def info(self) -> InfoResponse: + """This should simply return the expected InfoResponse""" + raise NotImplementedError + + def warmup(self, batch: Batch) -> int: + """Verify if the hardware can support the target load. + + Args: + batch (`Batch`): + A batch corresponding to the maximum number of concurrent requests. 
+ + Return: + The maximum number of tokens the model supports. + """ + raise NotImplementedError + + def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: + """Prefill is called whenever new requests need to be added. + + When this method returns successfully, a decode method will follow + with both the current and newly prefilled batch(es). + + Args: + batch (`Batch`): + A batch containing the new requests. + + Return: + A list of `Generation` for each request and a `CachedBatch` containing all pending requests. + """ + raise NotImplementedError + + def decode(self, batches: List[Batch]) -> Tuple[List[Generation], CachedBatch]: + """Decode after a prefill or another decode.""" + raise NotImplementedError + + def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch: + """Remove requests that are not listed from the specified batch""" + raise NotImplementedError + + def clear(self): + """Remove all requests from the generator""" + raise NotImplementedError + + @classmethod + def from_pretrained(cls, model_id: str, revision: Optional[str]): + """Factory method "a la transformers" """ + raise NotImplementedError + + +class Slot: + """Represents a slot in a static batch""" + + class State(Enum): + EMPTY = 0 + PAUSE = 1 + READY = 2 + + def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase): + self._id = id + self._tokenizer = tokenizer + self.clear() + + def clear(self): + """Clear the slot and mark it as available.""" + self._state = Slot.State.EMPTY + self._batch_id = None + self._request_id = None + self._inputs = "" + self._truncate = 0 + self._generation_config = None + self._tokens = [] + self._mask = torch.tensor([]) + self._selector = None + self._generated_tokens = 0 + self._next_text_token_start = 0 + self._next_text_token_end = 0 + self._generated_text = "" + self._next_text = "" + + @property + def id(self) -> int: + return self._id + + @property + def state(self) -> "Slot.State": + return self._state + + @property + def batch_id(self) -> int: + return self._batch_id + + @property + def request_id(self) -> int: + return self._request_id + + @property + def cached_text(self) -> str: + return self._inputs + self._generated_text + + @property + def generation_config(self) -> GenerationConfig: + return self._generation_config + + @property + def generated_tokens(self) -> int: + return self._generated_tokens + + def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig): + """Assign a request to a slot. + + Args: + request (`Request`): + The request to be assigned. Contains the inputs and tokens selection parameters. + generation_config (`transformers.GenerationConfig`): + The base generation config (might be modified by the request generation parameters). 
+ """ + self._state = Slot.State.READY + self._batch_id = batch_id + self._request_id = request.id + self._inputs = request.inputs + if request.truncate: + self._truncate = request.truncate + self._generation_config = copy.deepcopy(generation_config) + # Update generation config with request parameters + self._generation_config.do_sample = request.parameters.do_sample + if self._generation_config.do_sample: + if request.parameters.temperature != 0: + self._generation_config.temperature = request.parameters.temperature + if request.parameters.top_k != 0: + self._generation_config.top_k = request.parameters.top_k + if request.parameters.top_p != 0: + self._generation_config.top_p = request.parameters.top_p + if request.parameters.typical_p != 0: + self._generation_config.typical_p = request.parameters.typical_p + if request.parameters.repetition_penalty != 0: + self._generation_config.repetition_penalty = request.parameters.repetition_penalty + self.seed = request.parameters.seed + self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens + self._max_new_tokens = self._generation_config.max_new_tokens + stop_strings = request.stopping_parameters.stop_sequences + if stop_strings: + self._generation_config.stop_strings = stop_strings + + def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector): + """Reset the slot for the next generation. + + Args: + input_ids: (`torch.LongTensor`): + The new input_ids to use to generate the next token. + attention_mask: (`torch.LongTensor`): + The new attention_mask to use to generate the next token. + selector: (`optimum.neuron.generation.TokenSelector`): + An object implementing the updated token selection logic. + """ + self._tokens = input_ids.clone() + self._next_text_token_start = 0 + self._next_text_token_end = torch.numel(self._tokens) + self._next_text = "" + self._mask = attention_mask.clone() + self._selector = selector + + def pause(self, reset_on_pause: bool): + """Mark the current slot as paused for generation. + + Note that the KV cache for this slot will still be filled. + """ + if reset_on_pause: + # Drop the last token as it will be added back when resuming the slot + self._generated_tokens -= 1 + # Since generated tokens are now part of the prefill, we need to reevaluate + # max_new_tokens for the next generation + self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens + self._state = Slot.State.PAUSE + + def resume(self): + """Mark the slot as ready for generation.""" + self._state = Slot.State.READY + + def _decode_next_tokens( + self, + ) -> str: + """Hack to hopefully support generate_stream for the maximum number of tokenizers""" + # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode + # which decide to add a space or not depending on the surrounding ids. + new_text = self._tokenizer.decode(self._tokens[self._next_text_token_start :], skip_special_tokens=False) + if new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. 
+ return "" + + # Compare the generated text with the one using only the tokens producing the last one + last_text = self._tokenizer.decode( + self._tokens[self._next_text_token_start : self._next_text_token_end], + skip_special_tokens=False, + ) + if len(new_text) == len(last_text): + # Nothing new was actually generated + return "" + # Return the decoded text and store its token offsets + self._next_text_token_start = self._next_text_token_end + self._next_text_token_end = torch.numel(self._tokens) + return new_text[len(last_text) :] + + def append(self, next_token: int) -> str: + """Append a new generated token to this slot + + The new token is added to the list of generated tokens, which impacts + directly the generated_text and stopped property. + + The new token is however not added immediately to the slot inputs: it will + be added later on when it has effectively been used to produce the next token. + + Args: + next_token (`int`): + The newly generated token. + + Return: + The corresponding decoded text (if any). + """ + self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])]) + self._mask = torch.cat([self._mask, torch.LongTensor([1])]) + self._generated_tokens += 1 + next_text = self._decode_next_tokens() + # Now that a new token has been generated, we can append the previous one to the generated text + self._generated_text += self._next_text + self._next_text = next_text + return next_text + + def select(self, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.LongTensor: + """Select the next token from the candidate logits. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation (not used in all generation modes). + logits (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The logits corresponding to the generated tokens. + + Return: + `torch.LongTensor`: A scalar torch.LongTensor` containing the selected token. + """ + return self._selector.select(input_ids, logits)[0] + + @property + def stopped(self) -> bool: + # Transformers stopping criteria expects a batch of input ids + input_ids = torch.unsqueeze(self._tokens, dim=0) + return self._selector.stopping_criteria(input_ids, None) + + @property + def generated_text(self) -> str: + return self._generated_text + self._next_text + + @property + def next_token(self) -> int: + return None if len(self._tokens) == 0 else self._tokens[-1] + + @property + def attention_mask(self) -> torch.LongTensor: + return self._mask + + @property + def max_token(self) -> int: + return self._generation_config.max_length + + @property + def max_new_tokens(self) -> int: + # The current value of max_new_tokens: might be different of the target max_new_tokens + # if the slot has been paused and resumed. 
+ return self._generation_config.max_new_tokens + + @property + def truncate(self) -> int: + return self._truncate + + +class NeuronGenerator(Generator): + """A Generator for Neuron models.""" + + def __init__( + self, + model: NeuronModelForCausalLM, + tokenizer: PreTrainedTokenizerBase, + ): + self.model = model + self.rebuild_cache_on_prefill = not self.model.continuous_batching + # Specify padding and truncation options for decoder-only architecture + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" + self.tokenizer = tokenizer + self.special_tokens = self.tokenizer.all_special_ids + self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)] + self.batch_id = 0 + + @property + def info(self) -> InfoResponse: + """Returns the expected InfoResponse.""" + dtype = getattr(self.model.config, "torch_dtype", "float32") + return InfoResponse( + requires_padding=True, + dtype=str(dtype), + device_type="xla", + ) + + def warmup(self, batch: Batch) -> int: + """Verify if the hardware can support the target load. + + Args: + batch (`Batch`): + A batch corresponding to the maximum number of concurrent requests. + + Return: + The maximum number of tokens the model supports. + """ + # Just check that the warmup request parameters match the model capacity + batch_size = self.model.batch_size + if len(batch.requests) > batch_size: + raise ValueError( + f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." + ) + self.prefill(batch) + self.clear() + return self.model.batch_size * self.model.max_length + + def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: + """Prefill new requests. + + Args: + batch (`Batch`): + A batch containing the new requests. + + Return: + A list of `Generation` for each request and a `CachedBatch` containing all pending requests. + """ + slots = {state: [] for state in Slot.State} + for slot in self.slots: + slots[slot.state].append(slot) + active_slots = slots[Slot.State.READY] + empty_slots = slots[Slot.State.EMPTY] + if len(empty_slots) < len(batch.requests): + raise ValueError( + f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots." + f" Please align max_batch_size with the static batch size: {self.model.batch_size}." + ) + # Assign each request to an empty slot + logger.debug(f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)") + new_slots = [] + for request in batch.requests: + slot = empty_slots.pop() + slot.assign(self.batch_id, request, self.model.generation_config) + new_slots.append(slot) + logger.debug( + f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}" + ) + if self.rebuild_cache_on_prefill: + # We will clear pending slots and prefill all slots + prefill_slots = self.slots + seq_ids = None + else: + # We only need to pass inputs for the new requests + prefill_slots = new_slots + seq_ids = torch.tensor([slot.id for slot in prefill_slots]) + # Reconstruct the full inputs (without padding) as seen by the model. 
+ # This comprises: + # - the inputs for new requests, + # - only when rebuilding the cache, the inputs and the generated text that has already + # been cached (i.e. excluding the last generated token) for unfinished requests. + inputs = [] + max_length = 0 + for slot in prefill_slots: + inputs.append(slot.cached_text) + # Apply truncation, making sure we fit into static dimensions + if slot.truncate == 0: + max_length = self.model.max_length + elif slot.truncate > max_length and slot.truncate < self.model.max_length: + max_length = slot.truncate + # Tokenize with padding and truncation + padded_inputs = self.tokenizer( + inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length + ) + input_ids = padded_inputs.input_ids + attention_mask = padded_inputs.attention_mask + # Pause previously active slots during generation + next_tokens = [] + for slot in active_slots: + slot.pause(reset_on_pause=self.rebuild_cache_on_prefill) + if self.rebuild_cache_on_prefill: + # The slot will be reset, so we need to store its next token + next_tokens.append(slot.next_token) + # Each slot must be reset with the padded inputs and masks + for i, slot in enumerate(prefill_slots): + if slot.state != slot.state.EMPTY: + if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]: + # Apply per-request truncation + input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id + attention_mask[i, : -slot.truncate] = 0 + slot_input_ids = input_ids[i : i + 1, :] + # Padded input ids are also required to set logits processors and stopping criterias + selector = TokenSelector.create( + slot_input_ids, + slot.generation_config, + self.model, + self.model.max_length, + tokenizer=self.tokenizer, + seed=slot.seed, + ) + slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64) + slot_attention_mask = attention_mask[i] + slot.reset(slot_input_ids, slot_attention_mask, selector) + # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored, + # as they have already been generated and sent back in the last decode. + model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids) + logits = self.model(**model_inputs)[0] + generation, next_batch = self._generate_token(prefill_slots, self.batch_id, logits, input_ids) + self.batch_id += 1 + # Reactivate previously active slots for the next decode + for i, slot in enumerate(active_slots): + slot.resume() + if self.rebuild_cache_on_prefill: + # Append back the next token + slot.append(next_tokens[i]) + logger.debug("Model ready for decoding") + if next_batch is not None: + logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}") + return generation, next_batch + + def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]: + """Decode the specified prefilled requests. + + Args: + batches (`List[CachedBatch]`): + A list of previous batches containing the prefilled requests. + + Return: + A list of `Generation` for each request and a `CachedBatch` containing all pending requests. + """ + # batches contains a list composed of: + # - the batch id returned by the last decode, + # - the batch id(s) returned by the last prefill(s) + # Batches are always concatenated during prefill, so we can + # just carry on with decoding. We adopt the id of the first + # batch in the list as our next batch id. 
+ next_batch_id = batches[0].id + request_ids = [] + for batch in batches: + request_ids += batch.request_ids + cleared_request_ids = [] + for slot in self.slots: + if slot.state == slot.State.READY and slot.request_id not in request_ids: + cleared_request_ids.append(slot.request_id) + slot.clear() + if len(cleared_request_ids) > 0: + logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.") + active_slots = [slot for slot in self.slots if slot.state == slot.State.READY] + if len(active_slots) < len(request_ids): + raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)") + if self.model.continuous_batching: + decode_slots = active_slots + seq_ids = torch.tensor([slot.id for slot in decode_slots]) + else: + decode_slots = self.slots + seq_ids = None + # Reconstruct input_ids and attention_mask from decode slots + n_slots = len(decode_slots) + input_ids = torch.full([n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64) + max_length = 0 + for slot in decode_slots: + max_length = max(max_length, slot.attention_mask.size(-1)) + attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64) + for i, slot in enumerate(decode_slots): + if slot.state != Slot.State.EMPTY: + # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached) + input_ids[i, 0] = slot.next_token + attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask + model_inputs = self.model.prepare_inputs_for_decode(input_ids, attention_mask, seq_ids) + logits = self.model(**model_inputs)[0] + return self._generate_token(decode_slots, next_batch_id, logits, input_ids) + + def _generate_token( + self, slots: List[Slot], next_batch_id: int, logits: torch.Tensor, input_ids: torch.LongTensor + ) -> Tuple[List[Generation], CachedBatch]: + generations = [] + active_slots = False + for i, slot in enumerate(slots): + if slot.state != Slot.State.READY: + continue + request_id = slot.request_id + next_token_logits = logits[i : i + 1, -1, :] + slot_input_ids = input_ids[i : i + 1, :] + next_token = slot.select(slot_input_ids, next_token_logits) + next_token_text = slot.append(next_token) + generated_text = None + finish_reason = None + if next_token == self.tokenizer.eos_token_id: + finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN + elif slot.stopped: + if slot.generated_tokens == slot.max_new_tokens: + finish_reason = FinishReason.FINISH_REASON_LENGTH + else: + finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE + if finish_reason is not None: + # We must include the generated text for each finished sequence in the response + generated_text = GeneratedText( + text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason + ) + logger.debug(f"Decode complete for request {request_id} with {slot.generated_tokens} tokens") + # mark the slot as available + slot.clear() + else: + active_slots = True + generations.append( + Generation( + request_id=request_id, + prefill_tokens=None, + tokens=Tokens( + ids=[next_token], + logprobs=[0], + texts=[next_token_text], + is_special=[next_token in self.special_tokens], + ), + generated_text=generated_text, + ) + ) + batch = None + if active_slots: + # Whatever initial batch these requests came from, we always return all pending requests in a single batch + request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] + batch = self._cached_batch(next_batch_id, request_ids) + 
else: + logger.debug("No more pending requests") + return generations, batch + + def _cached_batch(self, batch_id: int, request_ids: List): + size = len(request_ids) + max_tokens = size * self.model.max_length + return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens) + + def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch: + """Remove requests that are not listed from the specified batch + + Args: + batch_id (`int`): + The id of a cached batch. + keep_ids(`List[int]`): + The list of requests that must be kept. + + Return: + A `CachedBatch` containing the pending requests. + """ + keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids] + self._clear(keep_slot_ids) + return self._cached_batch(batch_id, keep_request_ids) + + def clear(self, batch_id: Optional[int] = None): + """Remove a subset or all requests from the generator""" + keep_ids = [] + if batch_id is not None: + keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id] + return self._clear(keep_ids) + + def _clear(self, keep_slot_ids: List): + for slot in self.slots: + if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids: + logger.debug(f"Removing slot {slot.id} with request {slot.request_id}") + slot.clear() + + @classmethod + def from_pretrained(cls, model_id: str, revision: str = None): + """Instantiate a NeuronGenerator. + + Args: + model_id (`str`): + A hub model id or the path to a local model. This path must also contain a Tokenizer. + revision (`Optional[str]`, defaults to `None`): + The revision of the model on the HuggingFace hub. + + Returns: + A NeuronGenerator. + """ + config = AutoConfig.from_pretrained(model_id) + neuron_config = getattr(config, "neuron", None) + start = time.time() + if neuron_config is None: + export_kwargs = get_export_kwargs_from_env() + logger.info(f"Exporting model to neuron with config: {export_kwargs}.") + model = NeuronModelForCausalLM.from_pretrained( + model_id, revision=revision, low_cpu_mem_usage=True, export=True, **export_kwargs + ) + else: + logger.info("Loading model on neuron devices (this can take a few minutes).") + model = NeuronModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, revision=revision) + end = time.time() + logger.info(f"Model successfully loaded in {end - start:.2f} s.") + tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) + return cls(model, tokenizer) diff --git a/backends/neuron/server/text_generation_server/interceptor.py b/backends/neuron/server/text_generation_server/interceptor.py new file mode 100644 index 00000000..ed29cdf2 --- /dev/null +++ b/backends/neuron/server/text_generation_server/interceptor.py @@ -0,0 +1,27 @@ +from typing import Any, Callable + +import grpc +from google.rpc import code_pb2, status_pb2 +from grpc_interceptor.server import AsyncServerInterceptor +from grpc_status import rpc_status +from loguru import logger + + +class ExceptionInterceptor(AsyncServerInterceptor): + async def intercept( + self, + method: Callable, + request_or_iterator: Any, + context: grpc.ServicerContext, + method_name: str, + ) -> Any: + try: + response = method(request_or_iterator, context) + return await response + except Exception as err: + method_name = method_name.split("/")[-1] + logger.exception(f"Method {method_name} encountered an error.") + + await context.abort_with_status( + rpc_status.to_status(status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))) + ) diff --git 
a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py new file mode 100644 index 00000000..e8cb34ee --- /dev/null +++ b/backends/neuron/server/text_generation_server/model.py @@ -0,0 +1,118 @@ +import os +import shutil +import time +from typing import Optional + +from huggingface_hub import snapshot_download +from huggingface_hub.constants import HF_HUB_CACHE +from loguru import logger +from transformers import AutoConfig + +from optimum.neuron import NeuronModelForCausalLM +from optimum.neuron.utils import get_hub_cached_entries + + +def get_export_kwargs_from_env(): + batch_size = os.environ.get("MAX_BATCH_SIZE", None) + if batch_size is not None: + batch_size = int(batch_size) + sequence_length = os.environ.get("MAX_TOTAL_TOKENS", None) + if sequence_length is not None: + sequence_length = int(sequence_length) + num_cores = os.environ.get("HF_NUM_CORES", None) + if num_cores is not None: + num_cores = int(num_cores) + auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None) + return { + "task": "text-generation", + "batch_size": batch_size, + "sequence_length": sequence_length, + "num_cores": num_cores, + "auto_cast_type": auto_cast_type, + } + + +def is_cached(model_id, neuron_config): + # Look for cached entries for the specified model + in_cache = False + entries = get_hub_cached_entries(model_id, "inference") + # Look for compatible entries + for entry in entries: + compatible = True + for key, value in neuron_config.items(): + # Only weights can be different + if key in ["checkpoint_id", "checkpoint_revision"]: + continue + if entry[key] != value: + compatible = False + if compatible: + in_cache = True + break + return in_cache + + +def log_cache_size(): + path = HF_HUB_CACHE + if os.path.exists(path): + usage = shutil.disk_usage(path) + gb = 2**30 + logger.info(f"Cache disk [{path}]: total = {usage.total / gb:.2f} G, free = {usage.free / gb:.2f} G") + else: + raise ValueError(f"The cache directory ({path}) does not exist.") + + +def fetch_model( + model_id: str, + revision: Optional[str] = None, +) -> str: + """Fetch a neuron model. + + Args: + model_id (`str`): + The *model_id* of a model on the HuggingFace hub or the path to a local model. + revision (`Optional[str]`, defaults to `None`): + The revision of the model on the HuggingFace hub. + + Returns: + A string corresponding to the model_id or path. + """ + if not os.path.isdir("/sys/class/neuron_device/"): + raise SystemError("No neuron cores detected on the host.") + if os.path.isdir(model_id) and revision is not None: + logger.warning("Revision {} ignored for local model at {}".format(revision, model_id)) + revision = None + # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model) + # Note that the model may already be present in the cache. 
+ config = AutoConfig.from_pretrained(model_id, revision=revision) + neuron_config = getattr(config, "neuron", None) + if neuron_config is not None: + if os.path.isdir(model_id): + return model_id + # Prefetch the neuron model from the Hub + logger.info(f"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}") + log_cache_size() + return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") + # Model needs to be exported: look for compatible cached entries on the hub + export_kwargs = get_export_kwargs_from_env() + export_config = NeuronModelForCausalLM.get_export_config(model_id, config, revision=revision, **export_kwargs) + neuron_config = export_config.neuron + if not is_cached(model_id, neuron_config): + hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache" + neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi" + error_msg = ( + f"No cached version found for {model_id} with {neuron_config}." + f"You can start a discussion to request it on {hub_cache_url}" + f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}" + ) + raise ValueError(error_msg) + logger.warning(f"{model_id} is not a neuron model: it will be exported using cached artifacts.") + if os.path.isdir(model_id): + return model_id + # Prefetch weights, tokenizer and generation config so that they are in cache + log_cache_size() + start = time.time() + snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") + end = time.time() + logger.info(f"Model weights fetched in {end - start:.2f} s.") + log_cache_size() + return model_id diff --git a/backends/neuron/server/text_generation_server/server.py b/backends/neuron/server/text_generation_server/server.py new file mode 100644 index 00000000..8eb2592d --- /dev/null +++ b/backends/neuron/server/text_generation_server/server.py @@ -0,0 +1,89 @@ +import asyncio +from pathlib import Path +from typing import List + +from grpc import aio +from grpc_reflection.v1alpha import reflection +from loguru import logger + +from .generator import Generator, NeuronGenerator +from .interceptor import ExceptionInterceptor +from .pb import generate_pb2, generate_pb2_grpc + + +class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): + def __init__(self, generator: Generator, server_urls: List[str]): + self.generator = generator + self.server_urls = server_urls + + async def Info(self, request, context): + return self.generator.info + + async def Health(self, request, context): + return generate_pb2.HealthResponse() + + async def ServiceDiscovery(self, request, context): + return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) + + async def ClearCache(self, request, context): + if request.HasField("id"): + self.generator.clear(request.id) + else: + self.generator.clear() + return generate_pb2.ClearCacheResponse() + + async def FilterBatch(self, request, context): + filtered_batch = self.generator.filter(request.batch_id, request.request_ids) + return generate_pb2.FilterBatchResponse(batch=filtered_batch) + + async def Warmup(self, request, context): + max_tokens = self.generator.warmup(request.batch) + return generate_pb2.WarmupResponse(max_supported_total_tokens=max_tokens) + + async def Prefill(self, request, context): + generations, batch = self.generator.prefill(request.batch) + return generate_pb2.PrefillResponse(generations=generations, batch=batch) + + async def 
Decode(self, request, context): + generations, batch = self.generator.decode(request.batches) + return generate_pb2.DecodeResponse(generations=generations, batch=batch) + + +def serve( + model_id: str, + revision: str, + uds_path: Path, +): + async def serve_inner(model_id: str, revision: str): + unix_socket_template = "unix://{}-{}" + local_url = unix_socket_template.format(uds_path, 0) + server_urls = [local_url] + + try: + generator = NeuronGenerator.from_pretrained(model_id, revision) + except Exception: + logger.exception("Error when initializing model") + raise + + server = aio.server(interceptors=[ExceptionInterceptor()]) + generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( + TextGenerationService(generator, server_urls), server + ) + SERVICE_NAMES = ( + generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(SERVICE_NAMES, server) + server.add_insecure_port(local_url) + + await server.start() + + logger.info("Server started at {}".format(local_url)) + + try: + await server.wait_for_termination() + except KeyboardInterrupt: + logger.info("Signal received. Shutting down") + await server.stop(0) + + asyncio.run(serve_inner(model_id, revision)) diff --git a/backends/neuron/tgi-entrypoint.sh b/backends/neuron/tgi-entrypoint.sh new file mode 100755 index 00000000..b959a795 --- /dev/null +++ b/backends/neuron/tgi-entrypoint.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -e -o pipefail -u + +export ENV_FILEPATH=$(mktemp) + +trap "rm -f ${ENV_FILEPATH}" EXIT + +touch $ENV_FILEPATH + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +${SCRIPT_DIR}/tgi_env.py $@ + +source $ENV_FILEPATH + +exec text-generation-launcher $@ diff --git a/backends/neuron/tgi_env.py b/backends/neuron/tgi_env.py new file mode 100755 index 00000000..ff647c98 --- /dev/null +++ b/backends/neuron/tgi_env.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python + +import argparse +import logging +import os +import sys +from typing import Any, Dict, List, Optional + +from huggingface_hub import constants +from transformers import AutoConfig + +from optimum.neuron.modeling_decoder import get_available_cores +from optimum.neuron.utils import get_hub_cached_entries +from optimum.neuron.utils.version_utils import get_neuronxcc_version + + +logger = logging.getLogger(__name__) + +tgi_router_env_vars = ["MAX_BATCH_SIZE", "MAX_TOTAL_TOKENS", "MAX_INPUT_TOKENS", "MAX_BATCH_PREFILL_TOKENS"] +tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"] + +env_config_peering = [ + ("MAX_BATCH_SIZE", "batch_size"), + ("MAX_TOTAL_TOKENS", "sequence_length"), + ("HF_AUTO_CAST_TYPE", "auto_cast_type"), + ("HF_NUM_CORES", "num_cores"), +] + +# By the end of this script all env var should be specified properly +env_vars = tgi_server_env_vars + tgi_router_env_vars + +available_cores = get_available_cores() +neuronxcc_version = get_neuronxcc_version() + + +def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace: + parser = argparse.ArgumentParser() + if not argv: + argv = sys.argv + # All these are params passed to tgi and intercepted here + parser.add_argument( + "--max-input-tokens", type=int, default=os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)) + ) + parser.add_argument("--max-total-tokens", type=int, default=os.getenv("MAX_TOTAL_TOKENS", 0)) + parser.add_argument("--max-batch-size", type=int, default=os.getenv("MAX_BATCH_SIZE", 0)) + parser.add_argument("--max-batch-prefill-tokens", 
type=int, default=os.getenv("MAX_BATCH_PREFILL_TOKENS", 0)) + parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID")) + parser.add_argument("--revision", type=str, default=os.getenv("REVISION")) + + args = parser.parse_known_args(argv)[0] + + if not args.model_id: + raise Exception("No model id provided ! Either specify it using --model-id cmdline or MODEL_ID env var") + + # Override env with cmdline params + os.environ["MODEL_ID"] = args.model_id + + # Set all tgi router and tgi server values to consistent values as early as possible + # from the order of the parser defaults, the tgi router value can override the tgi server ones + if args.max_total_tokens > 0: + os.environ["MAX_TOTAL_TOKENS"] = str(args.max_total_tokens) + + if args.max_input_tokens > 0: + os.environ["MAX_INPUT_TOKENS"] = str(args.max_input_tokens) + + if args.max_batch_size > 0: + os.environ["MAX_BATCH_SIZE"] = str(args.max_batch_size) + + if args.max_batch_prefill_tokens > 0: + os.environ["MAX_BATCH_PREFILL_TOKENS"] = str(args.max_batch_prefill_tokens) + + if args.revision: + os.environ["REVISION"] = str(args.revision) + + return args + + +def neuron_config_to_env(neuron_config): + with open(os.environ["ENV_FILEPATH"], "w") as f: + for env_var, config_key in env_config_peering: + f.write("export {}={}\n".format(env_var, neuron_config[config_key])) + max_input_tokens = os.getenv("MAX_INPUT_TOKENS") + if not max_input_tokens: + max_input_tokens = int(neuron_config["sequence_length"]) // 2 + if max_input_tokens == 0: + raise Exception("Model sequence length should be greater than 1") + f.write("export MAX_INPUT_TOKENS={}\n".format(max_input_tokens)) + max_batch_prefill_tokens = os.getenv("MAX_BATCH_PREFILL_TOKENS") + if not max_batch_prefill_tokens: + max_batch_prefill_tokens = int(neuron_config["batch_size"]) * int(max_input_tokens) + f.write("export MAX_BATCH_PREFILL_TOKENS={}\n".format(max_batch_prefill_tokens)) + + +def sort_neuron_configs(dictionary): + return -dictionary["num_cores"], -dictionary["batch_size"] + + +def lookup_compatible_cached_model(model_id: str, revision: Optional[str]) -> Optional[Dict[str, Any]]: + # Reuse the same mechanic as the one in use to configure the tgi server part + # The only difference here is that we stay as flexible as possible on the compatibility part + entries = get_hub_cached_entries(model_id, "inference") + + logger.debug("Found %d cached entries for model %s, revision %s", len(entries), model_id, revision) + + all_compatible = [] + for entry in entries: + if check_env_and_neuron_config_compatibility(entry, check_compiler_version=True): + all_compatible.append(entry) + + if not all_compatible: + logger.debug( + "No compatible cached entry found for model %s, env %s, available cores %s, neuronxcc version %s", + model_id, + get_env_dict(), + available_cores, + neuronxcc_version, + ) + return None + + logger.info("%d compatible neuron cached models found", len(all_compatible)) + + all_compatible = sorted(all_compatible, key=sort_neuron_configs) + + entry = all_compatible[0] + + return entry + + +def check_env_and_neuron_config_compatibility(neuron_config: Dict[str, Any], check_compiler_version: bool) -> bool: + logger.debug( + "Checking the provided neuron config %s is compatible with the local setup and provided environment", + neuron_config, + ) + + # Local setup compat checks + if neuron_config["num_cores"] > available_cores: + logger.debug("Not enough neuron cores available to run the provided neuron config") + return False + + if check_compiler_version and 
neuron_config["compiler_version"] != neuronxcc_version: + logger.debug( + "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)", + neuronxcc_version, + neuron_config["compiler_version"], + ) + return False + + for env_var, config_key in env_config_peering: + neuron_config_value = str(neuron_config[config_key]) + env_value = os.getenv(env_var, str(neuron_config_value)) + if env_value != neuron_config_value: + logger.debug( + "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)", + env_var, + config_key, + env_value, + neuron_config_value, + ) + return False + + max_input_tokens = int(os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))) + if max_input_tokens > 0: + sequence_length = neuron_config["sequence_length"] + if max_input_tokens >= sequence_length: + logger.debug( + "Specified max input tokens is not compatible with config sequence length ( %s >= %s)", + max_input_tokens, + sequence_length, + ) + return False + + return True + + +def get_env_dict() -> Dict[str, str]: + d = {} + for k in env_vars: + d[k] = os.getenv(k) + return d + + +def main(): + """ + This script determines proper default TGI env variables for the neuron precompiled models to + work properly + :return: + """ + args = parse_cmdline_and_set_env() + + for env_var in env_vars: + if not os.getenv(env_var): + break + else: + logger.info("All env vars %s already set, skipping, user know what they are doing", env_vars) + sys.exit(0) + + cache_dir = constants.HF_HUB_CACHE + + logger.info("Cache dir %s, model %s", cache_dir, args.model_id) + + config = AutoConfig.from_pretrained(args.model_id, revision=args.revision) + neuron_config = getattr(config, "neuron", None) + if neuron_config is not None: + compatible = check_env_and_neuron_config_compatibility(neuron_config, check_compiler_version=False) + if not compatible: + env_dict = get_env_dict() + msg = ( + "Invalid neuron config and env. Config {}, env {}, available cores {}, neuronxcc version {}" + ).format(neuron_config, env_dict, available_cores, neuronxcc_version) + logger.error(msg) + raise Exception(msg) + else: + neuron_config = lookup_compatible_cached_model(args.model_id, args.revision) + + if not neuron_config: + msg = ("No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}").format( + get_env_dict(), available_cores, neuronxcc_version + ) + logger.error(msg) + raise Exception(msg) + + neuron_config_to_env(neuron_config) + + +if __name__ == "__main__": + main() diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index e073353f..39f0ef4b 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -50,6 +50,8 @@ title: Train Medusa title: Tutorials - sections: + - local: backends/neuron + title: Neuron - local: backends/trtllm title: TensorRT-LLM - local: backends/llamacpp diff --git a/docs/source/backends/neuron.md b/docs/source/backends/neuron.md new file mode 100644 index 00000000..50f70fb2 --- /dev/null +++ b/docs/source/backends/neuron.md @@ -0,0 +1,182 @@ +# Neuron backend for AWS Trainium and Inferentia + +The Neuron backend allows the deployment of TGI on AWS Trainium and Inferentia family of chips. + +The following hardware targets are supported: +- Trainium 1, +- Inferentia 2. 
+
+## Features
+
+The basic TGI features are supported:
+
+- continuous batching,
+- token streaming,
+- greedy search and multinomial sampling using [transformers](https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation).
+
+
+## Deploy the service from the Hugging Face Hub
+
+The simplest way to deploy the NeuronX TGI service for a specific model is to follow the
+deployment instructions in the model card:
+
+- click on the "Deploy" button on the right,
+- select your deployment service ("Inference Endpoints" and "SageMaker" are supported),
+- select "AWS Trainium & Inferentia",
+- follow the instructions.
+
+
+## Deploy the service on a dedicated host
+
+The service is launched simply by running the text-generation-inference container with two sets of parameters:
+
+```
+docker run <system parameters> ghcr.io/huggingface/text-generation-inference:latest-neuron <service parameters>
+```
+
+- system parameters are used to map ports, volumes and devices between the host and the service,
+- service parameters are forwarded to the `text-generation-launcher`.
+
+When deploying a service, you will need a pre-compiled Neuron model. The Neuron TGI backend supports two main modes of operation:
+
+- you can either deploy the service on a model that has already been exported to Neuron,
+- or alternatively you can take advantage of the Neuron Model Cache to export your own model.
+
+### Common system parameters
+
+Whenever you launch a TGI service, we highly recommend that you mount a shared volume as `/data` in the container: this is where
+the models will be cached to speed up further instantiations of the service.
+
+Note also that enough neuron devices must be visible to the container. The simplest way to achieve that is to launch the service in `privileged` mode to get access to all neuron devices.
+Alternatively, each device can be explicitly exposed using the `--device` option.
+
+Finally, you might want to export the `HF_TOKEN` if you want to access gated repositories.
+
+Here is an example of a service instantiation:
+
+```
+docker run -p 8080:80 \
+       -v $(pwd)/data:/data \
+       --privileged \
+       -e HF_TOKEN=${HF_TOKEN} \
+       ghcr.io/huggingface/text-generation-inference:latest-neuron \
+       <service parameters>
+```
+
+If you only want to map the first device, the launch command becomes:
+
+```
+docker run -p 8080:80 \
+       -v $(pwd)/data:/data \
+       --device=/dev/neuron0 \
+       -e HF_TOKEN=${HF_TOKEN} \
+       ghcr.io/huggingface/text-generation-inference:latest-neuron \
+       <service parameters>
+```
+
+### Using a standard model from the 🤗 [HuggingFace Hub](https://huggingface.co/aws-neuron) (recommended)
+
+We maintain a Neuron Model Cache of the most popular architectures and deployment parameters under [aws-neuron/optimum-neuron-cache](https://huggingface.co/aws-neuron/optimum-neuron-cache).
+
+If you just want to try the service quickly with a model that has not been exported to Neuron yet, you can still do so, under two conditions:
+- you must specify the export parameters when launching the service (or use default parameters),
+- the model configuration must be cached (you can check this as sketched below).
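+
+Before launching the service, you can verify that a cached configuration exists for your model. A minimal sketch, reusing the `get_hub_cached_entries` helper that the backend's `tgi_env.py` script relies on (the import path is an assumption and may differ between optimum-neuron versions):
+
+```
+from optimum.neuron.utils import get_hub_cached_entries
+
+model_id = "meta-llama/Meta-Llama-3-8B"
+# Same lookup mode ("inference") as the one used by tgi_env.py
+entries = get_hub_cached_entries(model_id, "inference")
+print(f"{len(entries)} cached neuron configuration(s) found for {model_id}")
+for entry in entries:
+    # Each entry is a dict of export parameters (batch size, sequence length, cores, dtype, ...)
+    print({k: entry.get(k) for k in ("batch_size", "sequence_length", "num_cores", "auto_cast_type")})
+```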
+
+The snippet below shows how you can deploy a service from a hub standard model:
+
+```
+export HF_TOKEN=<your token>
+docker run -p 8080:80 \
+       -v $(pwd)/data:/data \
+       --privileged \
+       -e HF_TOKEN=${HF_TOKEN} \
+       -e HF_AUTO_CAST_TYPE="fp16" \
+       -e HF_NUM_CORES=2 \
+       ghcr.io/huggingface/text-generation-inference:latest-neuron \
+       --model-id meta-llama/Meta-Llama-3-8B \
+       --max-batch-size 1 \
+       --max-input-length 3164 \
+       --max-total-tokens 4096
+```
+
+### Using a model exported to a local path
+
+Alternatively, you can first [export the model to neuron format](https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-text-generation-inference:latest-neuron) locally.
+
+You can then deploy the service inside the shared volume:
+
+```
+docker run -p 8080:80 \
+       -v $(pwd)/data:/data \
+       --privileged \
+       ghcr.io/huggingface/text-generation-inference:latest-neuron \
+       --model-id /data/<neuron model path>
+```
+
+Note: You don't need to specify any service parameters, as they will all be deduced from the model export configuration.
+
+### Using a neuron model from the 🤗 [HuggingFace Hub](https://huggingface.co/)
+
+The easiest way to share a neuron model inside your organization is to push it on the Hugging Face hub, so that it can be deployed directly without requiring an export.
+
+The snippet below shows how you can deploy a service from a hub neuron model:
+
+```
+docker run -p 8080:80 \
+       -v $(pwd)/data:/data \
+       --privileged \
+       -e HF_TOKEN=${HF_TOKEN} \
+       ghcr.io/huggingface/text-generation-inference:latest-neuron \
+       --model-id <organization>/<neuron model>
+```
+
+### Choosing service parameters
+
+Use the following command to list the available service parameters:
+
+```
+docker run ghcr.io/huggingface/text-generation-inference:latest-neuron --help
+```
+
+The configuration of an inference endpoint is always a compromise between throughput and latency: serving more requests in parallel allows a higher throughput, but it increases the latency.
+
+The neuron models have static input dimensions `[batch_size, max_length]`.
+
+This adds several restrictions to the following parameters:
+
+- `--max-batch-size` must be set to `batch_size`,
+- `--max-input-length` must be lower than `max_length`,
+- `--max-total-tokens` must be set to `max_length` (it is per-request).
+
+Although not strictly necessary, the following is important for efficient prefilling:
+
+- `--max-batch-prefill-tokens` should be set to `batch_size` * `max-input-length`.
+
+### Choosing the correct batch size
+
+As seen in the previous paragraph, the static batch size of a neuron model has a direct influence on the endpoint latency and throughput.
+
+Please refer to [text-generation-inference](https://github.com/huggingface/text-generation-inference) for optimization hints.
+
+Note that the main constraint is to be able to fit the model for the specified `batch_size` within the total device memory available
+on your instance (16GB per neuron core, with 2 cores per device).
+
+## Query the service
+
+You can query the model using either the `/generate` or `/generate_stream` routes:
+
+```
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+    -H 'Content-Type: application/json'
+```
+
+```
+curl 127.0.0.1:8080/generate_stream \
+    -X POST \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+    -H 'Content-Type: application/json'
+```
+
+Note: replace `127.0.0.1:8080` with the actual IP address and port of your service.
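+
+If you prefer querying the service from Python rather than curl, here is a minimal sketch using the `huggingface_hub` client (assuming `huggingface_hub` is installed on the client side):
+
+```
+from huggingface_hub import InferenceClient
+
+# Point the client at the TGI endpoint (adjust host and port to your deployment)
+client = InferenceClient("http://127.0.0.1:8080")
+
+# Equivalent to the /generate route
+print(client.text_generation("What is Deep Learning?", max_new_tokens=20))
+
+# Equivalent to the /generate_stream route
+for token in client.text_generation("What is Deep Learning?", max_new_tokens=20, stream=True):
+    print(token, end="", flush=True)
+```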