Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 21:12:07 +00:00)
Commit 7914e980e2
Merge branch 'main' into gaudi_backend_pa
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
.github/workflows/nix_build.yaml (2 changes)

@@ -47,7 +47,7 @@ jobs:
           if [ "${{ github.event_name }}" = "pull_request" ]; then
             export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
           else
-            export TAG=nix-${{ github.ref_name }}
+            export TAG=${{ github.ref_name }}-nix
           fi
           export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
           nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"
.github/workflows/nix_tests.yaml (1 change)

@@ -7,6 +7,7 @@ on:
       - "proto/**"
       - "router/**"
       - "launcher/**"
+      - "backends/**"
       - "Cargo.lock"
       - "rust-toolchain.toml"
 concurrency:
.github/workflows/tests.yaml (1 change)

@@ -8,6 +8,7 @@ on:
       - "proto/**"
       - "router/**"
       - "launcher/**"
+      - "backends/**"
       - "Cargo.lock"
       - "rust-toolchain.toml"
Cargo.lock (763 changes): file diff suppressed because it is too large.
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -18,7 +18,7 @@ RUN apt-get update -y \
     && rm -rf /var/lib/apt/lists/* \
     && apt-get clean

-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.0 --profile minimal -y
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
 ENV PATH="/root/.cargo/bin:${PATH}"
 RUN cargo install cargo-chef --locked

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -3,7 +3,7 @@ ARG HABANA_VERSION=1.20.0
 ARG PYTORCH_VERSION=2.6.0

 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -1,6 +1,6 @@
 ARG PLATFORM=xpu

-FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -41,7 +41,7 @@ RUN mkdir -p llama.cpp \

 WORKDIR /app
 COPY rust-toolchain.toml rust-toolchain.toml
-RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.0 --profile minimal -y
+RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
 ENV PATH="/root/.cargo/bin:$PATH"
 RUN cargo install cargo-chef --locked

@@ -71,7 +71,7 @@ ARG actions_runtime_token

 # Install Rust
 ENV PATH="/root/.cargo/bin:$PATH"
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.0 --profile minimal -y && \
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
     chmod -R a+w /root/.rustup && \
     chmod -R a+w /root/.cargo && \
     cargo install sccache --locked

@@ -20,6 +20,7 @@ import torch
 import torch.utils.checkpoint
 import numpy as np

+from loguru import logger
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
@@ -92,23 +93,6 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):

 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

-    def _merge_input_ids_with_image_features(
-        self,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-        input_ids: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -169,6 +153,92 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

         return outputs

+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+    def pack_image_features(
+        self,
+        image_features,
+        image_sizes,
+        vision_feature_select_strategy,
+        image_newline=None,
+    ):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensor, each contains all the visual feature of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each images (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                token length of each image in image_features
+        """
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = (
+                    self.config.vision_config.image_size
+                    // self.config.vision_config.patch_size
+                )
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape)
+                    % (num_patch_height * num_patch_width * height * width)
+                    != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (image_feature, image_newline[None].to(image_feature)), dim=0
+                    )
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(
+            feature_lens, dtype=torch.long, device=image_features.device
+        )
+        return image_features, feature_lens
+
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
     def get_image_features(
         self,
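For intuition, here is a minimal, self-contained sketch of the per-image reshaping that `pack_image_features` performs. All sizes are made up, and `unpad_image` plus the `image_newline` column are omitted for brevity:

```python
# Toy sketch of pack_image_features' per-image reshaping; all sizes are
# hypothetical, and unpad_image / the image_newline column are omitted.
import torch

embed_dim = 8
height = width = 4                        # image_size // patch_size
num_patch_height, num_patch_width = 2, 2  # anyres grid shape

# Patch 0 is the base (resized full image); the rest tile the 2x2 grid.
feats = torch.randn(1 + num_patch_height * num_patch_width, height * width, embed_dim)
base, tiles = feats[0], feats[1:]

tiles = tiles.view(num_patch_height, num_patch_width, height, width, -1)
tiles = tiles.permute(4, 0, 2, 1, 3).contiguous()  # (embed, gh, h, gw, w)
tiles = tiles.flatten(1, 2).flatten(2, 3)          # (embed, gh*h, gw*w)
tiles = tiles.flatten(1, 2).transpose(0, 1)        # back to (tokens, embed)

packed = torch.cat((base, tiles), dim=0)
print(packed.shape)  # torch.Size([80, 8]): 16 base tokens + 64 tile tokens
```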
@@ -303,61 +373,33 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
             )

             # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = (
-                self.config.vision_config.image_size
-                // self.config.vision_config.patch_size
-            )
-
-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
-
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError(
-                            "The number of patches is not consistent with the image size."
-                        )
-
-                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                        image_sizes[image_idx].tolist(),
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-
-                    image_feature = image_feature.view(
-                        num_patch_height, num_patch_width, height, width, -1
-                    )
-                    image_feature = image_feature.permute(
-                        4, 0, 2, 1, 3
-                    ).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(
-                        image_feature, image_sizes[image_idx]
-                    )
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(
-                                *image_feature.shape[:-1], 1
-                            ),
-                        ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat(
-                        (base_image_feature, image_feature), dim=0
-                    )
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat(
-                        (image_feature, self.image_newline[None]), dim=0
-                    )
-                new_image_features.append(image_feature)
-            image_features = torch.cat(new_image_features, dim=0)
-            inputs_embeds = self._merge_input_ids_with_image_features(
-                inputs_embeds, image_features, input_ids
-            )
+            image_features, feature_lens = self.pack_image_features(
+                image_features,
+                image_sizes,
+                vision_feature_select_strategy=vision_feature_select_strategy,
+                image_newline=self.image_newline,
+            )
+
+            special_image_mask = (
+                input_ids == self.config.image_token_index
+            ).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+                inputs_embeds.device
+            )
+            if inputs_embeds[special_image_mask].numel() != image_features.numel():
+                n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                n_image_features = image_features.shape[0]
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+
+            image_features = image_features.to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                special_image_mask, image_features
+            )

         # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
         # generation with cache
         elif past_key_values is not None:
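The new path replaces the in-place indexed assignment (removed above) with an explicit count check followed by `masked_scatter`. A toy sketch of that merge, with made-up token ids and sizes:

```python
# Toy sketch of the masked_scatter merge: positions holding the image token
# id are overwritten, in order, by the image feature rows (made-up sizes).
import torch

image_token_index = 32000
input_ids = torch.tensor([[1, 32000, 32000, 2]])
inputs_embeds = torch.zeros(1, 4, 4)  # (batch, seq, embed)
image_features = torch.ones(2, 4)     # one feature row per image token

special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

# Mirrors the ValueError guard in the diff above.
assert inputs_embeds[special_image_mask].numel() == image_features.numel()

inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1])  # tensor([1., 1., 1., 1.])
```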
@@ -424,6 +424,9 @@ class VlmCausalLMBatch(CausalLMBatch):
             else:
                 images.append(curr_image)

+        if is_warmup is True:
+            images += [images[0]] * (len(texts) - len(images))
+
         missing_inputs = 0
         dummy_images = None
         if is_warmup is False:
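The added warmup branch pads the image list so every prompt in the warmup batch has an image; a trivial illustration with placeholder values:

```python
# During warmup, reuse the first image so len(images) matches len(texts).
texts = ["prompt 0", "prompt 1", "prompt 2"]
images = ["image 0"]
is_warmup = True

if is_warmup is True:
    images += [images[0]] * (len(texts) - len(images))

print(images)  # ['image 0', 'image 0', 'image 0']
```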
@@ -1549,7 +1552,7 @@ class VlmCausalLM(Model):
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 max_prefill_batch_size,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True
@@ -1569,7 +1572,7 @@ class VlmCausalLM(Model):
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 2,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True
@@ -477,6 +477,15 @@ Options:
           [env: ENABLE_PREFILL_LOGPROBS=]

+```
+## GRACEFUL_TERMINATION_TIMEOUT
+```shell
+  -g, --graceful-termination-timeout <GRACEFUL_TERMINATION_TIMEOUT>
+          Change timeout of graceful termination of the TGI server
+
+          [env: GRACEFUL_TERMINATION_TIMEOUT=]
+          [default: 90]
+
 ```
 ## HELP
 ```shell
flake.lock (13 changes)
@@ -853,11 +853,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1741141853,
-      "narHash": "sha256-FauVtC+FbOgkKpGVuQTNxSqrvgbmVc7hFkjn/DacwMo=",
+      "lastModified": 1742783666,
+      "narHash": "sha256-IwdSl51NL6V0f+mYXZR0UTKaGleOsk9zV3l6kt5SUWw=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "02edad1f19d6dec824e0812e4cdc0aa7930ff8ae",
+      "rev": "60766d63c227d576510ecfb5edd3a687d56f6bc7",
       "type": "github"
     },
     "original": {
@@ -978,16 +978,15 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1741617161,
-      "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
+      "lastModified": 1742807335,
+      "narHash": "sha256-580CXhhCcK1cXRWahk/mDKo6zUsOD2JvNOg5SBBMAKc=",
       "owner": "huggingface",
       "repo": "text-generation-inference-nix",
-      "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
+      "rev": "8d9ea4691f49369aa5d3230a51366ff6c744d658",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
-      "ref": "kernels-0.2.0",
       "repo": "text-generation-inference-nix",
       "type": "github"
     }
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
    rust-overlay = {
@@ -892,6 +892,10 @@ struct Args {
     /// Using this flag reallows users to ask for them.
     #[clap(long, env)]
     enable_prefill_logprobs: bool,
+
+    /// Change timeout of graceful termination of the TGI server
+    #[clap(default_value = "90", long, short, env)]
+    graceful_termination_timeout: u64,
 }

 #[derive(Debug)]
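For readers unfamiliar with clap: the attribute above resolves the value from the CLI flag first, then the GRACEFUL_TERMINATION_TIMEOUT environment variable, then the default of 90. A rough Python argparse analogue, illustrative only and not part of the change:

```python
# Rough analogue of `#[clap(default_value = "90", long, short, env)]`:
# the CLI flag wins, else the environment variable, else the default.
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "-g", "--graceful-termination-timeout",
    type=int,
    default=int(os.environ.get("GRACEFUL_TERMINATION_TIMEOUT", "90")),
)
args = parser.parse_args([])  # empty argv: falls back to env var or 90
print(args.graceful_termination_timeout)
```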
@@ -933,6 +937,7 @@ fn shard_manager(
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
+    graceful_termination_timeout: u64,
     _shutdown_sender: mpsc::Sender<()>,
 ) {
     // Enter shard-manager tracing span
@@ -1206,7 +1211,12 @@ fn shard_manager(

         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            terminate("shard", p, Duration::from_secs(90)).unwrap();
+            terminate(
+                "shard",
+                p,
+                Duration::from_secs(graceful_termination_timeout),
+            )
+            .unwrap();
             return;
         }

@@ -1545,6 +1555,7 @@ fn spawn_shards(
     status_receiver: &mpsc::Receiver<ShardStatus>,
     status_sender: mpsc::Sender<ShardStatus>,
     running: Arc<AtomicBool>,
+    graceful_termination_timeout: u64,
 ) -> Result<(), LauncherError> {
     // Start shard processes
     for rank in 0..num_shard {
@@ -1615,6 +1626,7 @@ fn spawn_shards(
                 max_log_level,
                 status_sender,
                 shutdown,
+                graceful_termination_timeout,
                 shutdown_sender,
             )
         });
@@ -2002,6 +2014,8 @@ fn main() -> Result<(), LauncherError> {
     // Pattern match configuration
     let args: Args = Args::parse();

+    let graceful_termination_timeout = args.graceful_termination_timeout;
+
     // Filter events with LOG_LEVEL
     let varname = "LOG_LEVEL";
     let env_filter = if let Ok(log_level) = std::env::var(varname) {
@@ -2266,6 +2280,7 @@ fn main() -> Result<(), LauncherError> {
         &status_receiver,
         status_sender,
         running.clone(),
+        graceful_termination_timeout,
     )?;

     // We might have received a termination signal
@@ -2310,7 +2325,12 @@ fn main() -> Result<(), LauncherError> {
     }

     // Graceful termination
-    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
+    terminate(
+        "webserver",
+        webserver,
+        Duration::from_secs(graceful_termination_timeout),
+    )
+    .unwrap();
     shutdown_shards(shutdown, &shutdown_receiver);

     exit_code
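The launcher's grace-period pattern, sketched in Python for illustration (the launcher itself is Rust, and `terminate` here is a hypothetical stand-in): send SIGTERM, wait up to the configured timeout, then force-kill.

```python
# Illustrative sketch of graceful termination with a configurable timeout:
# SIGTERM first, SIGKILL only if the grace period expires.
import subprocess

def terminate(name: str, proc: subprocess.Popen, timeout: float) -> None:
    print(f"terminating {name} (grace period: {timeout}s)")
    proc.terminate()                # SIGTERM: ask the process to exit
    try:
        proc.wait(timeout=timeout)  # wait out the grace period
    except subprocess.TimeoutExpired:
        proc.kill()                 # SIGKILL: force exit
        proc.wait()

graceful_termination_timeout = 90.0  # matches the launcher's new default
server = subprocess.Popen(["sleep", "300"])
terminate("webserver", server, graceful_termination_timeout)
```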
@@ -18,8 +18,8 @@ final: prev: {
       src = final.fetchFromGitHub {
         owner = "huggingface";
         repo = "transformers";
-        rev = "v4.49.0";
-        hash = "sha256-drq7RWoRaRejiQjCUHIYuzaKa9rA4eQZI2do74scp1c=";
+        rev = "v4.50.0";
+        hash = "sha256-/scrMPUY43n+XAMbwWCtmiJKXscXGLrklyDg9XZTaqw=";
       };
     }
   );
@@ -1,5 +1,5 @@
 [toolchain]
 # Released on: 30 January, 2025
 # https://releases.rs/docs/1.84.1/
-channel = "1.85.0"
+channel = "1.85.1"
 components = ["rustfmt", "clippy"]