Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

commit 8015f5f258: Merge branch 'main' into add_vlm_chunking_optimized

.github/workflows/build.yaml (vendored)
@@ -45,7 +45,7 @@ jobs:
       uses: actions/github-script@v7
       with:
         script: |
-          core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
+          core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
           core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');

     - name: Extract TensorRT-LLM version
@@ -223,7 +223,7 @@ jobs:
             PLATFORM=${{ env.PLATFORM }}
             build_type=${{ env.BUILD_TYPE }}
             sccache_gha_enabled=on
-            actions_cache_url=${{ env.ACTIONS_CACHE_URL }}
+            actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
             actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
           target: ${{ env.TARGET }}
           tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}

@@ -3,10 +3,9 @@ ARG cuda_base=12.8.0
 ARG build_type=release
 ARG ompi_version=4.1.7
 ARG sccache_gha_enabled=off
-ARG actions_cache_url=""
+ARG actions_results_url=""
 ARG actions_runtime_token=""
-

 # CUDA dependent dependencies resolver stage
 FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder

@@ -66,7 +65,7 @@ WORKDIR /usr/src/text-generation-inference
 ARG cuda_arch_list
 ARG build_type
 ARG sccache_gha_enabled
-ARG actions_cache_url
+ARG actions_results_url
 ARG actions_runtime_token

 # Install Rust
@@ -74,7 +73,7 @@ ENV PATH="/root/.cargo/bin:$PATH"
 RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
     chmod -R a+w /root/.rustup && \
     chmod -R a+w /root/.cargo && \
-    cargo install sccache --locked
+    cargo install sccache --version ">=0.10.0" --locked

 ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
 ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
@@ -85,7 +84,7 @@ ENV CUDA_ARCH_LIST=${cuda_arch_list}

 # SCCACHE Specifics args - before finding a better, more generic, way...
 ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
-ENV ACTIONS_CACHE_URL=${actions_cache_url}
+ENV ACTIONS_RESULTS_URL=${actions_results_url}
 ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}

 COPY Cargo.lock Cargo.lock

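Note: the workflow and Dockerfile hunks above move the sccache GitHub Actions cache endpoint from ACTIONS_CACHE_URL to ACTIONS_RESULTS_URL and pin sccache to ">=0.10.0". As an illustrative sketch only (the helper name and dict keys below are assumptions, not part of the workflow), the github-script step essentially forwards the runner's cache-service environment into Docker build args:

    import os

    def collect_sccache_build_args() -> dict:
        """Hypothetical helper: mirror the runner's cache-service variables into
        Docker build args, defaulting to empty strings when they are unset."""
        return {
            "sccache_gha_enabled": "on",
            "actions_results_url": os.environ.get("ACTIONS_RESULTS_URL", ""),
            "actions_runtime_token": os.environ.get("ACTIONS_RUNTIME_TOKEN", ""),
        }

    if __name__ == "__main__":
        for key, value in collect_sccache_build_args().items():
            print(f"{key}={value}")
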
@@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
-    block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
     attn_bias: Optional[torch.Tensor]

@@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
         "block_list",
         "block_mapping",
         "block_usage",
-        "block_scales",
         "block_groups",
         "attn_bias",
     ],

@@ -74,7 +74,6 @@ def paged_attention(
         block_list=hpu_attention_meta.block_list,
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
-        block_scales=hpu_attention_meta.block_scales,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
         matmul_qk_op=Matmul(),

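The two Python hunks above drop block_scales both from the HPU paged-attention metadata and from the field list passed to trim_attn_metadata, so the kernel call no longer receives it. A rough sketch of the field-subsetting pattern, using a plain dataclass and a hypothetical trim helper rather than the project's actual implementation:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class PagedAttentionMetadata:
        block_list: Optional[torch.Tensor]
        block_mapping: Optional[torch.Tensor]
        block_usage: Optional[torch.Tensor]
        block_groups: Optional[torch.Tensor]
        attn_bias: Optional[torch.Tensor]

    def trim(metadata: PagedAttentionMetadata, fields: list) -> dict:
        # Keep only the attributes the attention kernel consumes; a removed
        # field such as block_scales simply stops being listed here.
        return {name: getattr(metadata, name) for name in fields}

    meta = PagedAttentionMetadata(None, None, None, None, None)
    trimmed = trim(
        meta, ["block_list", "block_mapping", "block_usage", "block_groups", "attn_bias"]
    )
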
@@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module):
         # bsz, q_len, _ = hidden_states.size()
         (
             cross_attention_states,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            cross_attention_len,
             indices,
         ) = cross_attention_states
-        bs = cu_seqlen_q.size(0) - 1
+        bs = cross_attention_len.size(0)
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
         query_states = self.q_norm(query_states)
@@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):

         indices = cross_attention_states[-1]
         out_hidden_states = hidden_states[:]
-        if len(indices) > 0:
-            assert max(indices) < hidden_states.shape[0]
         hidden_states = hidden_states[indices]
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
-        # XXX: Putting these as optional so that the cuda warmup calls can go through.
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        indices=None,
+        cross_attention_len: Optional[torch.Tensor] = None,
     ):
         if cross_attention_states is not None:
-            seqlen_q = len(image_indices)
-            n_images = cross_attention_states.shape[0]
-            seqlen_k = cross_attention_states.shape[1]
-            device = cross_attention_states.device
-            if cu_seqlen_prefill is not None:
-                offset = 0
-                cu_q = []
-                indices = []
-                for index in image_indices:
-                    cu_q.append(offset)
-                    length = seqlen.input_lengths[index].item()
-                    assert index < seqlen.cu_seqlen_q.shape[0]
-                    input_ids_offset = seqlen.cu_seqlen_q[index]
-                    indices.extend(range(input_ids_offset, input_ids_offset + length))
-                    offset += length
-                cu_q.append(offset)
-                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
-
-                assert max(indices) < input_ids.shape[0]
-
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-            else:
-                cu_seqlen_q = torch.arange(
-                    seqlen_q + 1, device=device, dtype=torch.int32
-                )
-                seqlen_k = cross_attention_states.shape[1]
-                n_images = cross_attention_states.shape[0]
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-                indices = image_indices[:]
-
             cross_attention_states = (
                 cross_attention_states,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                cross_attention_len,
                 indices,
             )

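The Mllama change above replaces the cumulative cu_seqlen_q / cu_seqlen_k bookkeeping with a single per-image cross_attention_len tensor, so the batch size is read directly from its first dimension instead of from cumulative boundaries. A minimal sketch of the two conventions, with made-up values:

    import torch

    # Old convention: cumulative query offsets, one more entry than sequences.
    cu_seqlen_q = torch.tensor([0, 5, 9, 14], dtype=torch.int32)
    bs_old = cu_seqlen_q.size(0) - 1  # 3

    # New convention: one cross-attention length per image-bearing sequence.
    cross_attention_len = torch.tensor([5, 4, 5], dtype=torch.int32)
    bs_new = cross_attention_len.size(0)  # 3

    assert bs_old == bs_new
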
(File diff suppressed because it is too large.)

@@ -11,13 +11,18 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    prepare_for_decode,
 )
-from text_generation_server.models.globals import PREFIX_CACHING
+from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
 from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
 import habana_frameworks.torch as htorch
+from text_generation_server.utils.import_utils import (
+    synchronize,
+)
+import torch.nn.functional as F

 tracer = trace.get_tracer(__name__)

@@ -375,6 +380,91 @@ class FlashVlmCausalLM(FlashCausalLM):
     def max_past(self) -> Optional[int]:
         return getattr(self.model.text_model, "max_past", None)

+    def warmup_decode(
+        self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
+    ):
+        input_ids = torch.zeros(
+            batch_size, dtype=batch.input_ids.dtype, device=self.device
+        )
+        position_ids = torch.arange(
+            batch_size, dtype=batch.position_ids.dtype, device=self.device
+        )
+        if batch.position_ids is not None and batch.position_ids.dim() == 2:
+            # qwen2_vl and qwen2_5_vl case
+            position_ids = position_ids.unsqueeze(-1).repeat(
+                (1, batch.position_ids.shape[-1])
+            )
+        blocks = [block_num // batch_size for _ in range(batch_size)]
+        blocks[0] += block_num % batch_size
+        past_len = []
+        block_tables = []
+        slots = []
+        start_idx = 0
+
+        # fetch the last blocked to warmup block num
+        for i in range(batch_size):
+            block_array = list(range(start_idx, start_idx + blocks[i]))
+            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+            block_tables.append(block_array)
+            past_len.append(blocks[i] * BLOCK_SIZE - 1)
+            start_idx += blocks[i]
+        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.tensor(
+            past_len, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            batch_size,
+            bucketing_ctx=None,
+        )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            slots=slots_tensor,
+            seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=hpu_attention_meta,
+            lm_head_indices=None,
+            pixel_values=None,
+            pixel_attention_mask=None,
+            image_sizes=None,
+            image_grid_thw=None,
+        )
+
+    def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        warmup_times = 3
+        # only warmup decode, for prefill, image pixal size may change, make the warmup useless
+        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+        for i, (batch_size, block_num) in enumerate(
+            reversed(self.bucketing_ctx.decode_buckets)
+        ):
+            if batch_size > block_num:
+                continue
+            log_master(
+                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            )
+            for index in range(warmup_times):
+                self.warmup_decode(batch_size, block_num, batch)
+            synchronize(self.device)
+
     def forward(
         self,
         batch: FlashVlmCausalLMBatch,
@@ -450,17 +540,75 @@ class FlashVlmCausalLM(FlashCausalLM):

         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = batch.prefilling

-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-        )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
+        if self.bucketing_ctx is not None:
+            if batch.prefilling:
+                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
+                    input_lengths.shape[0]
+                )
+            else:
+                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    input_lengths.shape[0]
+                )
+        else:
+            padded_bs = input_lengths.shape[0]
+        orig_bs = input_lengths.shape[0]
+        if padded_bs != input_lengths.shape[0]:
+            padded_input_lengths = F.pad(
+                input_lengths,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            padded_cache_lengths_tensor = F.pad(
+                cache_lengths_tensor,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            if cu_seqlen_prefill is not None:
+                cu_seqlen_prefill = torch.zeros(
+                    padded_bs + 1, device=self.device, dtype=torch.int32
+                )
+                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
+            seqlen = Seqlen(
+                input_lengths=padded_input_lengths,
+                cache_lengths=padded_cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+            input_seq = input_ids.view(orig_bs, -1)
+            input_ids = F.pad(
+                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+            )
+            if position_ids.dim() == 2:
+                # qwen2_vl and qwen2_5_vl case
+                position_ids = F.pad(
+                    position_ids,
+                    (0, 0, 0, (padded_bs - orig_bs) * input_seq.shape[-1]),
+                    value=1,
+                )
+            else:
+                position_ids = F.pad(
+                    position_ids,
+                    (0, (padded_bs - orig_bs) * input_seq.shape[-1]),
+                    value=1,
+                )
+            slots = F.pad(
+                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+            )
+            if lm_head_indices is not None:
+                lm_head_indices = F.pad(
+                    lm_head_indices, (0, padded_bs - orig_bs), value=0
+                )
+        else:
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -476,8 +624,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             image_grid_thw=batch.image_grid_thw,
             **kwargs,
         )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
         if batch.pixel_attention_mask is not None:
@@ -486,4 +632,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch.image_sizes = None
         if batch.image_grid_thw is not None:
             batch.image_grid_thw = None
-        return logits, speculative_logits
+        return logits[:orig_bs], (
+            speculative_logits[:orig_bs] if speculative_logits is not None else None
+        )

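The forward changes above pad the batch up to a bucket size before calling the model (so lazy-mode HPU graphs see a fixed set of shapes) and slice the logits back to the original batch size afterwards. A simplified sketch of the pad-then-slice pattern, assuming a hard-coded bucket list and a stand-in model output instead of the real bucketing context:

    import torch
    import torch.nn.functional as F

    BUCKETS = [1, 2, 4, 8]  # assumed bucket sizes, for illustration only

    def pick_bucket(batch_size: int) -> int:
        return next(b for b in BUCKETS if b >= batch_size)

    input_lengths = torch.tensor([7, 3, 5], dtype=torch.int32)
    orig_bs = input_lengths.shape[0]
    padded_bs = pick_bucket(orig_bs)  # 4

    # Pad per-sequence tensors with neutral values, run the model on the
    # padded batch, then keep only the rows that belong to real requests.
    padded_lengths = F.pad(input_lengths, (0, padded_bs - orig_bs), value=0)
    logits = torch.randn(padded_bs, 32000)  # stand-in for the model output
    logits = logits[:orig_bs]
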
@@ -11,7 +11,9 @@ from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
+from text_generation_server.models.flash_causal_lm import (
+    prepare_for_decode,
+)
 from text_generation_server.models.flash_vlm_causal_lm import (
     FlashVlmCausalLMBatch,
     FlashVlmCausalLM,
@@ -19,6 +21,13 @@ from text_generation_server.models.flash_vlm_causal_lm import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
 import habana_frameworks.torch as htorch
+from loguru import logger
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.import_utils import (
+    synchronize,
+)
+import torch.nn.functional as F
+from text_generation_server.utils.log import log_master

 tracer = trace.get_tracer(__name__)

@@ -196,7 +205,178 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
         return batch


+def generate_cross_attention_states(
+    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
+):
+    if cross_attention_states is None:
+        return None, None, None
+    device = cross_attention_states.device
+    indices_list = []
+    if prefilling:
+        for i in image_indices:
+            indices_list.append(
+                torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
+            )
+        indices = torch.cat(indices_list, dim=0)
+    else:
+        indices = image_indices[:]
+    return indices, seqlen.input_lengths.index_select(0, image_indices)
+
+
 class FlashMllamaCausalLM(FlashVlmCausalLM):
+    def warmup_decode(
+        self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
+    ):
+        input_ids = torch.zeros(
+            batch_size, dtype=batch.input_ids.dtype, device=self.device
+        )
+        position_ids = torch.arange(
+            batch_size, dtype=batch.position_ids.dtype, device=self.device
+        )
+        blocks = [block_num // batch_size for _ in range(batch_size)]
+        blocks[0] += block_num % batch_size
+        past_len = []
+        block_tables = []
+        slots = []
+        start_idx = 0
+
+        # fetch the last blocked to warmup block num
+        for i in range(batch_size):
+            block_array = list(range(start_idx, start_idx + blocks[i]))
+            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
+            block_tables.append(block_array)
+            past_len.append(blocks[i] * BLOCK_SIZE - 1)
+            start_idx += blocks[i]
+        input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.tensor(
+            past_len, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            batch_size,
+            bucketing_ctx=None,
+        )
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, 1, False
+        )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            slots=slots_tensor,
+            seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=hpu_attention_meta,
+            lm_head_indices=None,
+            adapter_data=None,
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
+        )
+
+    def warmup_prefill(
+        self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
+    ):
+        input_ids = torch.zeros(
+            prompt_len, dtype=batch.input_ids.dtype, device=self.device
+        ).repeat(batch_size)
+        position_ids = torch.arange(
+            prompt_len, dtype=batch.position_ids.dtype, device=self.device
+        ).repeat(batch_size)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).reshape(batch_size, -1)
+        slot_acc = []
+        for i in range(batch_size):
+            slots = []
+            for b in block_tables[i]:
+                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
+            slot_acc.extend(slots[:prompt_len])
+        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
+
+        input_lengths = (
+            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
+        )
+        cache_lengths_tensor = torch.zeros(
+            batch_size, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+        lm_head_indices = input_lengths - 1
+
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, prompt_len, True
+        )
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=self.kv_cache,
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=None,
+            lm_head_indices=lm_head_indices,
+            adapter_data=None,
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
+        )
+
+    def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+        warmup_times = 3
+        self.bucketing_ctx.generate_prompt_buckets()
+        for i, (batch_size, seq_len) in enumerate(
+            reversed(self.bucketing_ctx.prompt_buckets)
+        ):
+            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
+            for index in range(warmup_times):
+                self.warmup_prefill(seq_len, batch_size, batch)
+        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
+        for i, (batch_size, block_num) in enumerate(
+            reversed(self.bucketing_ctx.decode_buckets)
+        ):
+            if batch_size > block_num:
+                continue
+            log_master(
+                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            )
+            for index in range(warmup_times):
+                self.warmup_decode(batch_size, block_num, batch)
+            synchronize(self.device)
+
     def forward(
         self,
         batch: FlashMllamaCausalLMBatch,
@@ -263,12 +443,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-        )
-
         if batch.pixel_values is not None:
             cross_attention_states = self.model.vision_forward(
                 pixel_values=batch.pixel_values,
@@ -281,11 +455,82 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):

         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = batch.prefilling
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
             slots_pad[batch.prefill_cache_indices] = slots
             slots = slots_pad
+
+        if self.bucketing_ctx is not None:
+            if batch.prefilling:
+                padded_bs = self.bucketing_ctx.get_padded_prompt_batch_size(
+                    input_lengths.shape[0]
+                )
+            else:
+                padded_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    input_lengths.shape[0]
+                )
+        else:
+            padded_bs = input_lengths.shape[0]
+        orig_bs = input_lengths.shape[0]
+        padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        if padded_bs != input_lengths.shape[0]:
+            padded_input_lengths = F.pad(
+                input_lengths,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            padded_cache_lengths_tensor = F.pad(
+                cache_lengths_tensor,
+                (0, padded_bs - orig_bs),
+                value=0,
+            )
+            if cu_seqlen_prefill is not None:
+                cu_seqlen_prefill = torch.zeros(
+                    padded_bs + 1, device=self.device, dtype=torch.int32
+                )
+                torch.cumsum(padded_input_lengths, -1, out=cu_seqlen_prefill[1:])
+            seqlen = Seqlen(
+                input_lengths=padded_input_lengths,
+                cache_lengths=padded_cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+
+            input_ids = F.pad(
+                input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
+            )
+            position_ids = F.pad(
+                position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
+            )
+            slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
+            if lm_head_indices is not None:
+                lm_head_indices = F.pad(
+                    lm_head_indices, (0, padded_bs - orig_bs), value=0
+                )
+            if cross_attention_states is not None:
+                cross_attention_states = F.pad(
+                    cross_attention_states,
+                    (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
+                    value=0,
+                )
+            if len(image_indices) != 0:
+                pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
+                image_indices = torch.cat((image_indices, pad_indices), dim=0)
+        else:
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+            )
+
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states,
+            image_indices,
+            seqlen,
+            padded_input_len,
+            batch.prefilling,
+        )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -295,14 +540,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
             lm_head_indices=lm_head_indices,
-            cross_attention_states=cross_attention_states,
             # TODO list
             adapter_data=None,
-            image_indices=batch.image_indices[:],
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
             **kwargs,
         )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
         if batch.pixel_values is not None:
             batch.pixel_values = None
-        return logits, speculative_logits
+        return logits[:orig_bs], (
+            speculative_logits[:orig_bs] if speculative_logits is not None else None
+        )

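The new generate_cross_attention_states helper above builds, during prefill, a flat index tensor selecting every padded token position of each image-bearing sequence, while in decode it simply reuses the image indices. A small worked example of the prefill branch with assumed values:

    import torch

    pad_seq_len = 4                       # assumed padded prompt length
    image_indices = torch.tensor([0, 2])  # sequences 0 and 2 carry an image

    indices = torch.cat(
        [
            torch.arange(pad_seq_len * i, pad_seq_len * (i + 1))
            for i in image_indices.tolist()
        ],
        dim=0,
    )
    print(indices)  # tensor([ 0,  1,  2,  3,  8,  9, 10, 11])
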
@@ -119,6 +119,9 @@ struct Args {
     #[clap(default_value = "3000", long, short, env)]
     port: u16,

+    #[clap(default_value = "9000", long, short, env)]
+    prometheus_port: u16,
+
     /// Enable JSON output format.
     #[clap(long, env)]
     json_output: bool,
@@ -317,6 +320,7 @@ async fn main() -> Result<(), RouterError> {
         args.max_client_batch_size,
         args.usage_stats,
         args.payload_limit,
+        args.prometheus_port,
     )
     .await?;
     Ok(())

@@ -37,6 +37,8 @@ struct Args {
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
+    #[clap(default_value = "9000", long, short, env)]
+    prometheus_port: u16,
     #[clap(long, env, required = true)]
     tokenizer_name: String,
     #[clap(long, env)]
@@ -227,6 +229,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         max_batch_total_tokens,
         hostname,
         port,
+        prometheus_port,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -322,6 +325,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        prometheus_port,
     )
     .await?;
     Ok(())

@@ -177,7 +177,7 @@ impl Allocator for SimpleAllocator {
             (required_blocks, repeats)
         };

-        let tokens = tokens as usize;
+        let mut tokens = tokens as usize;
         if required_blocks > self.free_blocks.len() as u32 {
             None
         } else {
@@ -189,6 +189,8 @@ impl Allocator for SimpleAllocator {
                 .split_off(self.free_blocks.len() - required_blocks as usize);
             if self.is_hpu_device {
                 blocks.sort();
+                // need 1 slot for ping-pong optimization
+                tokens += 1;
             }
             let mut slots =
                 Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);

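The allocator change above makes tokens mutable and, on HPU devices, reserves one extra slot for the "ping-pong optimization" before the slot list is materialised. A rough Python sketch of the idea, with an assumed block size and a simplified slot enumeration that is not the real allocator:

    BLOCK_SIZE = 16  # assumed block size for illustration

    def allocate_slots(blocks: list, tokens: int, is_hpu_device: bool) -> list:
        if is_hpu_device:
            blocks = sorted(blocks)
            tokens += 1  # keep one spare slot for the ping-pong buffer
        # Enumerate every slot covered by the allocated blocks, then keep only
        # as many as the (possibly bumped) token count requires.
        slots = [b * BLOCK_SIZE + off for b in blocks for off in range(BLOCK_SIZE)]
        return slots[:tokens]

    print(len(allocate_slots([3, 1], tokens=20, is_hpu_device=True)))   # 21
    print(len(allocate_slots([3, 1], tokens=20, is_hpu_device=False)))  # 20
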
@@ -163,7 +163,7 @@ WORKDIR /usr/src/text-generation-inference
 ARG cuda_arch_list
 ARG build_type
 ARG sccache_gha_enabled
-ARG actions_cache_url
+ARG actions_results_url
 ARG actions_runtime_token

 # Install Rust
@@ -171,7 +171,7 @@ ENV PATH="/root/.cargo/bin:$PATH"
 RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
     chmod -R a+w /root/.rustup && \
     chmod -R a+w /root/.cargo && \
-    cargo install sccache --locked
+    cargo install sccache --version ">=0.10.0" --locked

 ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
 ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"

|
|||||||
|
use axum::{extract::Request, middleware::Next, response::Response};
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
use opentelemetry::sdk::trace;
|
use opentelemetry::sdk::trace;
|
||||||
use opentelemetry::sdk::trace::Sampler;
|
use opentelemetry::sdk::trace::Sampler;
|
||||||
use opentelemetry::sdk::Resource;
|
use opentelemetry::sdk::Resource;
|
||||||
|
use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
|
||||||
|
use opentelemetry::Context;
|
||||||
use opentelemetry::{global, KeyValue};
|
use opentelemetry::{global, KeyValue};
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||||
|
|
||||||
|
struct TraceParent {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
version: u8,
|
||||||
|
trace_id: TraceId,
|
||||||
|
parent_id: SpanId,
|
||||||
|
trace_flags: TraceFlags,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
|
||||||
|
let parts: Vec<&str> = header_value.split('-').collect();
|
||||||
|
if parts.len() != 4 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let version = u8::from_str_radix(parts[0], 16).ok()?;
|
||||||
|
if version == 0xff {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let trace_id = TraceId::from_hex(parts[1]).ok()?;
|
||||||
|
let parent_id = SpanId::from_hex(parts[2]).ok()?;
|
||||||
|
let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;
|
||||||
|
|
||||||
|
Some(TraceParent {
|
||||||
|
version,
|
||||||
|
trace_id,
|
||||||
|
parent_id,
|
||||||
|
trace_flags: TraceFlags::new(trace_flags),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
|
||||||
|
let context = request
|
||||||
|
.headers()
|
||||||
|
.get("traceparent")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(parse_traceparent)
|
||||||
|
.map(|traceparent| {
|
||||||
|
Context::new().with_remote_span_context(SpanContext::new(
|
||||||
|
traceparent.trace_id,
|
||||||
|
traceparent.parent_id,
|
||||||
|
traceparent.trace_flags,
|
||||||
|
true,
|
||||||
|
Default::default(),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
|
||||||
|
request.extensions_mut().insert(context);
|
||||||
|
|
||||||
|
next.run(request).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||||
/// - otlp_service_name service name to appear in APM
|
/// - otlp_service_name service name to appear in APM
|
||||||
|
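parse_traceparent above follows the W3C trace-context header layout: four dash-separated hex fields, version-traceid-parentid-flags, with version 0xff rejected. An equivalent standard-library sketch in Python (the example header value is made up):

    from typing import Optional, Tuple

    def parse_traceparent(header: str) -> Optional[Tuple[int, str, str, int]]:
        parts = header.split("-")
        if len(parts) != 4:
            return None
        try:
            version = int(parts[0], 16)
            int(parts[1], 16)  # trace id must be hex
            int(parts[2], 16)  # parent span id must be hex
            flags = int(parts[3], 16)
        except ValueError:
            return None
        if version == 0xFF:
            return None
        trace_id, parent_id = parts[1], parts[2]
        # Stand-in for the TraceId::from_hex / SpanId::from_hex length checks.
        if len(trace_id) != 32 or len(parent_id) != 16:
            return None
        return version, trace_id, parent_id, flags

    print(parse_traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"))
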
@@ -67,16 +67,26 @@ pub(crate) async fn sagemaker_compatibility(
     default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
+    context: Extension<Option<opentelemetry::Context>>,
     info: Extension<Info>,
     Json(req): Json<SagemakerRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     match req {
         SagemakerRequest::Generate(req) => {
-            compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
+            compat_generate(
+                default_return_full_text,
+                infer,
+                compute_type,
+                context,
+                Json(req),
+            )
+            .await
+        }
+        SagemakerRequest::Chat(req) => {
+            chat_completions(infer, compute_type, info, context, Json(req)).await
         }
-        SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
         SagemakerRequest::Completion(req) => {
-            completions(infer, compute_type, info, Json(req)).await
+            completions(infer, compute_type, info, context, Json(req)).await
         }
     }
 }

@@ -7,6 +7,7 @@ use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
     kserve_model_metadata, kserve_model_metadata_ready,
 };
+use crate::logging::trace_context_middleware;
 use crate::sagemaker::{
     sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
     __path_sagemaker_compatibility,
@@ -63,6 +64,7 @@ use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;

@@ -125,6 +127,7 @@ pub(crate) async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
+    context: Extension<Option<opentelemetry::Context>>,
     Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // default return_full_text given the pipeline_tag
@@ -134,11 +137,14 @@ pub(crate) async fn compat_generate(

     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, compute_type, Json(req.into()))
-            .await
-            .into_response())
+        Ok(
+            generate_stream(infer, compute_type, context, Json(req.into()))
+                .await
+                .into_response(),
+        )
     } else {
-        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
+        let (headers, Json(generation)) =
+            generate(infer, compute_type, context, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation])).into_response())
     }
@@ -267,9 +273,14 @@ seed,
 async fn generate(
     infer: Extension<Infer>,
     Extension(ComputeType(compute_type)): Extension<ComputeType>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let (headers, _, response) =
         generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
     Ok((headers, response))
@@ -465,12 +476,17 @@ seed,
 async fn generate_stream(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let (headers, response_stream) =
         generate_stream_internal(infer, compute_type, Json(req), span).await;

@@ -700,9 +716,14 @@ pub(crate) async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     metrics::counter!("tgi_request_count").increment(1);

     let CompletionRequest {
@@ -1148,9 +1169,14 @@ pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(mut chat): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     metrics::counter!("tgi_request_count").increment(1);
     let ChatRequest {
         model,
@@ -2265,6 +2291,7 @@ async fn start(
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
         .layer(DefaultBodyLimit::max(payload_limit))
+        .layer(axum::middleware::from_fn(trace_context_middleware))
         .layer(cors_layer);

     tracing::info!("Connected");
@@ -7,6 +7,7 @@ use axum::response::{IntoResponse, Response};
 use axum::Json;
 use serde::{Deserialize, Serialize};
 use tracing::instrument;
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 use utoipa::ToSchema;

 #[derive(Clone, Deserialize, ToSchema)]
@@ -70,9 +71,14 @@ example = json ! ({"error": "Incomplete generation"})),
 pub(crate) async fn vertex_compatibility(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<VertexRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     metrics::counter!("tgi_request_count").increment(1);

     // check that theres at least one instance

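On the handler side, each route now receives the middleware's extracted Context and attaches it with span.set_parent(context), so the request spans join the caller's trace. For comparison, a hedged sketch of the same idea using the OpenTelemetry Python SDK (API names as commonly documented for opentelemetry-python; not part of this repository):

    from opentelemetry import trace
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    tracer = trace.get_tracer(__name__)

    def handle_request(headers: dict) -> str:
        # Extract the remote parent from the incoming traceparent header (if any)
        # and start the handler span as its child.
        parent_ctx = TraceContextTextMapPropagator().extract(carrier=headers)
        with tracer.start_as_current_span("handle_request", context=parent_ctx) as span:
            span.set_attribute("component", "example")
            return "ok"

    handle_request(
        {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}
    )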