Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)

Commit 36b6612f97 (parent f95aa42660): adjust warmup and enable vlm

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -23,7 +23,7 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution

from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,

@@ -172,7 +172,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    prefill_cache_indices: Optional[torch.Tensor],
    lm_head_indices: Optional[torch.Tensor] = None,
    pixel_values: torch.FloatTensor = None,

@@ -279,8 +279,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    max_s=max_s,
    true_max_s=max_s,
    hpu_attention_meta=hpu_attention_meta,
    prefill_cache_indices=None,
    adapter_data=adapter_data,
)
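The hunks above follow the pattern repeated for every VLM wrapper touched by this commit: the forward signature gains an hpu_attention_meta: Optional[HPUPagedAttentionMetadata] argument, and the call into the text model passes it through while the old true_max_s and prefill_cache_indices=None arguments drop out of the call. Below is a minimal, self-contained sketch of that pass-through pattern; PagedAttentionMeta, TinyTextModel, and TinyVlmWrapper are illustrative stand-ins, not classes from the repository.

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn


@dataclass
class PagedAttentionMeta:
    # hypothetical stand-in for HPUPagedAttentionMetadata
    block_list: Optional[torch.Tensor] = None
    block_mapping: Optional[torch.Tensor] = None


class TinyTextModel(nn.Module):
    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.proj = nn.Linear(hidden, vocab)

    def forward(self, inputs_embeds, hpu_attention_meta=None):
        # A real model would hand the metadata to its paged-attention kernels;
        # the toy model only shows that it is threaded through unchanged.
        return self.proj(inputs_embeds)


class TinyVlmWrapper(nn.Module):
    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.text_model = TinyTextModel(hidden, vocab)

    def forward(self, input_ids, image_features=None, hpu_attention_meta=None):
        inputs_embeds = self.embed_tokens(input_ids)
        if image_features is not None:
            # naive merge: overwrite the leading positions with image features
            inputs_embeds[:, : image_features.shape[1]] = image_features
        # The wrapper treats the metadata as opaque and just forwards it.
        return self.text_model(inputs_embeds, hpu_attention_meta=hpu_attention_meta)


if __name__ == "__main__":
    model = TinyVlmWrapper()
    ids = torch.randint(0, 32, (1, 8))
    out = model(ids, hpu_attention_meta=PagedAttentionMeta())
    print(out.shape)  # torch.Size([1, 8, 32])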
@@ -31,6 +31,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.attention import (
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,

@@ -678,23 +679,23 @@ class MllamaTextCrossAttention(nn.Module):
    """Input shape: Batch x Time x Channel"""
    # hidden_states = hidden_states.unsqueeze(0)
    # bsz, q_len, _ = hidden_states.size()
    query_states = self.q_proj(hidden_states)
    query_states = query_states.view(-1, self.num_heads, self.head_size)
    query_states = self.q_norm(query_states)

    (
        cross_attention_states,
        cu_seqlen_q,
        cu_seqlen_k,
        max_q,
        max_k,
        indices,
    ) = cross_attention_states
    bs = cu_seqlen_q.size(0) - 1
    query_states = self.q_proj(hidden_states)
    query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
    query_states = self.q_norm(query_states)

    key_states = self.k_proj(cross_attention_states)
    value_states = self.v_proj(cross_attention_states)
    key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
    value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
    key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
    value_states = value_states.view(
        bs, -1, self.num_key_value_heads, self.head_size
    )
    key_states = self.k_norm(key_states)

    # key_states = key_states.repeat(1, self.num_key_value_groups, 1)

@@ -705,9 +706,9 @@ class MllamaTextCrossAttention(nn.Module):
    # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
    # )
    # execute sdpa
    query_states = query_states.unsqueeze(0).transpose(1, 2)
    key_states = key_states.unsqueeze(0).transpose(1, 2)
    value_states = value_states.unsqueeze(0).transpose(1, 2)
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    fsdpa_op = ModuleFusedSDPA(FusedSDPA)
    attn_output = fsdpa_op(
        query_states,

@@ -803,9 +804,10 @@ class FlashLlamaCrossLayer(torch.nn.Module):
    block_tables,
    slots,
    seqlen,
    max_s,
    adapter_data,
    cross_attention_states, # [ IB, ...]
    prefill_cache_indices,
    hpu_attention_meta,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if cross_attention_states is None:
        return hidden_states, residual

@@ -912,7 +914,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    prefill_cache_indices: Optional[torch.Tensor],
    lm_head_indices: Optional[torch.Tensor],
    adapter_data: Optional[torch.Tensor] = None,

@@ -949,8 +951,6 @@ class FlashMllamaForConditionalGeneration(nn.Module):
        )
        * seqlen_k
    )
    max_q = cu_seqlen_q[-1].item()
    max_k = seqlen_k
else:
    cu_seqlen_q = torch.arange(
        seqlen_q + 1, device=device, dtype=torch.int32

@@ -965,16 +965,12 @@ class FlashMllamaForConditionalGeneration(nn.Module):
        )
        * seqlen_k
    )
    max_q = seqlen_q
    max_k = seqlen_k
    indices = image_indices[:]

cross_attention_states = (
    cross_attention_states,
    cu_seqlen_q,
    cu_seqlen_k,
    max_q,
    max_k,
    indices,
)

@@ -986,7 +982,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    max_s=max_s,
    hpu_attention_meta=hpu_attention_meta,
    prefill_cache_indices=prefill_cache_indices,
    lm_head_indices=lm_head_indices,
    adapter_data=adapter_data,
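In MllamaTextCrossAttention, the query/key/value projections switch from a flattened (total_tokens, heads, head_size) view to a batched (bs, seq, heads, head_size) view, with bs derived from cu_seqlen_q, so the tensors only need a transpose, not an unsqueeze(0), before the fused SDPA call. A rough stand-alone illustration of the new layout, using PyTorch's built-in scaled_dot_product_attention in place of Habana's FusedSDPA:

import torch
import torch.nn.functional as F

bs, q_len, kv_len = 2, 5, 7
num_heads, num_kv_heads, head_size = 8, 8, 64

q = torch.randn(bs * q_len, num_heads * head_size)
k = torch.randn(bs * kv_len, num_kv_heads * head_size)
v = torch.randn(bs * kv_len, num_kv_heads * head_size)

# New layout: keep an explicit batch dimension from the start ...
q = q.view(bs, -1, num_heads, head_size)
k = k.view(bs, -1, num_kv_heads, head_size)
v = v.view(bs, -1, num_kv_heads, head_size)

# ... so only a transpose to (bs, heads, seq, head_size) is needed before SDPA,
# instead of the previous unsqueeze(0) over a flattened token dimension.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
)
print(out.transpose(1, 2).reshape(bs, q_len, -1).shape)  # torch.Size([2, 5, 512])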
@@ -19,7 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple

from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,

@@ -72,7 +72,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    prefill_cache_indices: Optional[torch.Tensor] = None,
    lm_head_indices: Optional[torch.Tensor] = None,
    pixel_values: torch.FloatTensor = None,

@@ -85,7 +85,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
    inputs_embeds = self.text_model.embed_tokens(input_ids)
    # TODO This is odd but apparently pali gemma position ids start at 1.
    if cu_seqlen_prefill is not None:
        max_s += 1
        position_ids += 1

    if pixel_values is not None:

@@ -110,7 +109,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    max_s=max_s,
    hpu_attention_meta=hpu_attention_meta,
)

if lm_head_indices is not None:
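The PaliGemma hunk keeps the quirk noted in the code comment (position ids start at 1) while the max_s bump disappears from the prefill branch. A toy illustration of the offset that remains:

import torch

position_ids = torch.arange(6).unsqueeze(0)  # [[0, 1, 2, 3, 4, 5]]
cu_seqlen_prefill = torch.tensor([0, 6])     # non-None means this is a prefill step

if cu_seqlen_prefill is not None:
    # PaliGemma expects 1-based positions during prefill.
    position_ids = position_ids + 1          # [[1, 2, 3, 4, 5, 6]]

print(position_ids)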
@@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

from text_generation_server.layers import (

@@ -742,7 +742,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    prefill_cache_indices: Optional[torch.Tensor],
    lm_head_indices: Optional[torch.Tensor] = None,
    pixel_values: torch.FloatTensor = None,

@@ -829,8 +829,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    max_s=max_s,
    true_max_s=max_s,
    hpu_attention_meta=hpu_attention_meta,
    prefill_cache_indices=None,
    adapter_data=adapter_data,
)
@@ -24,7 +24,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

from text_generation_server.layers import (

@@ -485,7 +485,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    prefill_cache_indices: Optional[torch.Tensor],
    lm_head_indices: Optional[torch.Tensor] = None,
    pixel_values: torch.FloatTensor = None,

@@ -573,8 +573,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    max_s=max_s,
    true_max_s=max_s,
    hpu_attention_meta=hpu_attention_meta,
    prefill_cache_indices=None,
    adapter_data=adapter_data,
)
@@ -40,6 +40,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.attention import (
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2Model,

@@ -906,7 +907,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    prefill_cache_indices: Optional[torch.Tensor],
    lm_head_indices: Optional[torch.Tensor],
    pixel_values: torch.FloatTensor = None,

@@ -937,8 +938,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    max_s=max_s,
    true_max_s=max_s,
    hpu_attention_meta=hpu_attention_meta,
    prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
@@ -39,6 +39,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.attention import (
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2Model,

@@ -482,7 +483,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    prefill_cache_indices: Optional[torch.Tensor],
    lm_head_indices: Optional[torch.Tensor],
    pixel_values: torch.FloatTensor = None,

@@ -512,8 +513,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    max_s=max_s,
    true_max_s=max_s,
    hpu_attention_meta=hpu_attention_meta,
    prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
@@ -61,7 +61,6 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,

@@ -77,10 +76,6 @@ tracer = trace.get_tracer(__name__)
SLIDING_WINDOW: Optional[int] = None


def small_power_of_2(n: int):
    return 1 << ((n - 1).bit_length() - 1)


def set_sliding_window(sliding_window: int):
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window

@@ -91,40 +86,6 @@ def get_sliding_windows() -> int:
    return SLIDING_WINDOW


def init_cpu_threads_env(rank_id: int, world_size: int):
    import importlib.util

    if importlib.util.find_spec("numa") is not None:
        import numa
        import psutil

        nodes = numa.info.get_max_node() + 1
        rank_per_node = math.ceil(world_size / nodes)
        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
        node_id = int(rank_id / rank_per_node)
        rank_offset_per_node = rank_id % rank_per_node
        if os.getenv("OMP_NUM_THREADS") is None:
            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
        else:
            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
        if len(numa.memory.get_membind_nodes()) == nodes:
            numa.memory.set_membind_nodes((node_id))
        torch.set_num_threads(num_cpus_per_rank)
        if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
            cpu_start = num_cpus_per_rank * rank_offset_per_node
            numa.schedule.run_on_cpus(
                0,
                *(
                    numa.info.node_to_cpus(node_id)[
                        cpu_start : cpu_start + num_cpus_per_rank
                    ]
                ),
            )
        logger.info(
            f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
        )


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int

@@ -1447,16 +1408,13 @@ class FlashCausalLM(Model):

    def warmup(
        self,
        request: generate_pb2.WarmupRequest,
        batch: FlashCausalLMBatch,
        max_input_tokens: Optional[int],
        max_total_tokens: Optional[int],
    ):
        # The warmup batch is the biggest batch we could ever receive
        self.kv_cache = []
        empty_cache()
        max_input_tokens = request.max_input_tokens
        max_total_tokens = request.max_total_tokens
        batch = self.batch_type.from_pb(
            request.batch, self.tokenizer, self.dtype, self.device
        )

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory

@@ -1505,10 +1463,10 @@ class FlashCausalLM(Model):
        )

        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
        if max_total_tokens is None or max_total_tokens == 0:
        if max_total_tokens is None:
            max_total_tokens = sum(batch.cache_lengths)

        if max_input_tokens is None or max_input_tokens == 0:
        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1

        del _batch, batch
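The warmup hunks change the method to receive an already-built FlashCausalLMBatch plus optional max_input_tokens/max_total_tokens instead of the raw WarmupRequest, deriving fallbacks from the batch when the limits are unset. A hedged sketch of that flow follows; DummyBatch and the block arithmetic are illustrative, not the repository's implementation.

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class DummyBatch:
    cache_lengths: List[int]


def warmup(
    batch: DummyBatch,
    max_input_tokens: Optional[int],
    max_total_tokens: Optional[int],
    num_blocks: int = 1024,
    block_size: int = 128,
) -> Tuple[int, int, int]:
    # The caller may leave the limits unset; derive them from the warmup batch.
    if max_total_tokens is None:
        max_total_tokens = sum(batch.cache_lengths)
    if max_input_tokens is None:
        max_input_tokens = max_total_tokens - 1
    # Stand-in for the KV-cache capacity computed from free memory.
    max_supported_total_tokens = num_blocks * block_size
    return max_supported_total_tokens, max_input_tokens, max_total_tokens


print(warmup(DummyBatch(cache_lengths=[4096, 4096]), None, None))
# (131072, 8191, 8192)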
@@ -16,7 +16,8 @@ from text_generation_server.models.globals import PREFIX_CACHING
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch

tracer = trace.get_tracer(__name__)

@@ -447,6 +448,10 @@ class FlashVlmCausalLM(FlashCausalLM):
    # This makes sure the max_s for the decode pass is correct.
    max_s = min(self.max_past(), max_s)

    kwargs = {}
    if htorch.utils.internal.is_lazy():
        kwargs["bypass_hpu_graphs"] = False

    seqlen = Seqlen(
        input_lengths=input_lengths,
        cache_lengths=cache_lengths_tensor,

@@ -459,13 +464,15 @@ class FlashVlmCausalLM(FlashCausalLM):
    kv_cache=kv_cache,
    block_tables=block_tables,
    slots=slots,
    seqlen=seqlen,
    seqlen=trim_seqlen_metadata(seqlen),
    hpu_attention_meta=batch.hpu_attn_meta,
    prefill_cache_indices=batch.prefill_cache_indices,
    lm_head_indices=lm_head_indices,
    pixel_values=batch.pixel_values,
    pixel_attention_mask=batch.pixel_attention_mask,
    image_sizes=batch.image_sizes,
    image_grid_thw=batch.image_grid_thw,
    **kwargs,
)
if batch.prefill_cache_indices is not None:
    batch.prefill_cache_indices = None
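The FlashVlmCausalLM hunks collect an optional HPU-graph argument in a kwargs dict when htorch.utils.internal.is_lazy() is true, and pass a trimmed Seqlen (trim_seqlen_metadata(seqlen)) plus hpu_attention_meta into the model call. The sketch below mimics only that call-site pattern; is_lazy, trim, and forward are local stand-ins, not the Habana or repository APIs.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Seqlen:
    input_lengths: torch.Tensor
    cache_lengths: Optional[torch.Tensor] = None


def trim(seqlen: Seqlen) -> Seqlen:
    # Illustrative: keep only the fields the downstream kernels actually consume.
    return Seqlen(input_lengths=seqlen.input_lengths)


def forward(*, seqlen: Seqlen, bypass_hpu_graphs: bool = True) -> str:
    return f"forward(lengths={seqlen.input_lengths.tolist()}, bypass={bypass_hpu_graphs})"


is_lazy = True  # stand-in for htorch.utils.internal.is_lazy()
kwargs = {}
if is_lazy:
    # Mirrors the diff: when lazy mode is active, an explicit
    # bypass_hpu_graphs flag is splatted into the forward call.
    kwargs["bypass_hpu_graphs"] = False

seqlen = Seqlen(input_lengths=torch.tensor([5, 7]), cache_lengths=torch.tensor([0, 0]))
print(forward(seqlen=trim(seqlen), **kwargs))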
@@ -17,8 +17,8 @@ from text_generation_server.models.flash_vlm_causal_lm import (
    FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen

from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch

tracer = trace.get_tracer(__name__)

@@ -279,6 +279,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):

    cross_attention_states = batch.cross_attention_states

    kwargs = {}
    if htorch.utils.internal.is_lazy():
        kwargs["bypass_hpu_graphs"] = False

    logits, speculative_logits = self.model.forward(
        input_ids=input_ids,
        position_ids=position_ids,

@@ -286,13 +290,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        kv_cache=kv_cache,
        block_tables=block_tables,
        slots=slots,
        seqlen=seqlen,
        max_s=max_s,
        seqlen=trim_seqlen_metadata(seqlen),
        hpu_attention_meta=batch.hpu_attn_meta,
        prefill_cache_indices=batch.prefill_cache_indices,
        lm_head_indices=lm_head_indices,
        cross_attention_states=cross_attention_states,
        adapter_data=adapter_data,
        # TODO list
        adapter_data=None,
        image_indices=batch.image_indices[:],
        **kwargs,
    )
    if batch.prefill_cache_indices is not None:
        batch.prefill_cache_indices = None
@@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
if max_batch_size_str is not None:
    MAX_BATCH_SIZE = int(max_batch_size_str)
else:
    raise ValueError("MAX_BATCH_SIZE is not set")


PREFILL_WARMUP_BATCH_SIZE_LIST = []
PREFILL_WARMUP_SEQLEN_LIST = []

@@ -1464,6 +1460,11 @@ class VlmCausalLM(Model):
    batch = self.batch_from_pb(request.batch, is_warmup=True)
    max_input_tokens = request.max_input_tokens
    max_prefill_batch_size = batch.input_ids.shape[0]
    max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
    if max_batch_size_str is not None:
        MAX_BATCH_SIZE = int(max_batch_size_str)
    else:
        raise ValueError("MAX_BATCH_SIZE is not set")

    try:
        # max prefill batch size warmup
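These hunks appear to move the MAX_BATCH_SIZE environment lookup from module import time into the warmup path, so an unset variable no longer fails at import. A small sketch of that relocation; the function names are illustrative, not the repository's.

import os


def resolve_max_batch_size() -> int:
    max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
    if max_batch_size_str is not None:
        return int(max_batch_size_str)
    raise ValueError("MAX_BATCH_SIZE is not set")


def warmup_example() -> int:
    # Resolved lazily, inside warmup, instead of at module import.
    return resolve_max_batch_size()


os.environ["MAX_BATCH_SIZE"] = "8"
print(warmup_example())  # 8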
@@ -18,10 +18,11 @@ from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
from text_generation_server.models.globals import set_model_id, ATTENTION
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch

@@ -109,14 +110,50 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

async def Warmup(self, request, context):
    max_supported_total_tokens, max_input_tokens, max_total_tokens = (
        self.model.warmup(request)
    )
    if ATTENTION == "paged":
        set_max_prefill_tokens(request.max_prefill_tokens)
        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch,
                self.model.tokenizer,
                self.model.dtype,
                self.model.device,
            )

        # W/A for the skip tokenizer path
        # We need to call make_tokenizer_optional after the warmup,
        # because router is not aware of that feature
        make_tokenizer_optional(self.model.tokenizer)
        # Override default values with None for clearer semantics.
        max_input_tokens = (
            request.max_input_tokens
            if request.HasField("max_input_tokens")
            else None
        )
        max_total_tokens = (
            request.max_total_tokens
            if request.HasField("max_total_tokens")
            else None
        )
        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
            self.model.warmup(batch, max_input_tokens, max_total_tokens)
        )
    else:
        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
            self.model.warmup(request)
        )

        # W/A for the skip tokenizer path
        # We need to call make_tokenizer_optional after the warmup,
        # because router is not aware of that feature
        make_tokenizer_optional(self.model.tokenizer)

    return generate_pb2.WarmupResponse(
        max_supported_total_tokens=max_supported_total_tokens,
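The Warmup handler now branches on ATTENTION == "paged": it builds the batch itself (via from_pb_processor for VLM batch types, otherwise from_pb), converts unset proto fields to None, and calls self.model.warmup(batch, max_input_tokens, max_total_tokens), while the legacy self.model.warmup(request) path is kept for other attention backends. A simplified control-flow sketch with stub objects; none of the stub names come from the repository.

from typing import Optional, Tuple


class FakeRequest:
    max_prefill_tokens = 4096
    max_input_tokens: Optional[int] = None   # an unset proto field becomes None
    max_total_tokens: Optional[int] = 8192
    batch = object()


class FakeModel:
    def batch_from_pb(self, pb):
        return pb

    def warmup(self, batch, max_input_tokens, max_total_tokens):
        max_total_tokens = max_total_tokens or 8192
        max_input_tokens = max_input_tokens or max_total_tokens - 1
        return 131072, max_input_tokens, max_total_tokens

    def warmup_from_request(self, request):
        return 131072, 1024, 2048


def warmup_handler(request, model, attention: str) -> Tuple[int, int, int]:
    if attention == "paged":
        # New path: build the batch here and hand the (possibly None) limits
        # to model.warmup, which fills in defaults.
        batch = model.batch_from_pb(request.batch)
        return model.warmup(batch, request.max_input_tokens, request.max_total_tokens)
    # Legacy path: the model consumes the raw warmup request itself.
    return model.warmup_from_request(request)


print(warmup_handler(FakeRequest(), FakeModel(), attention="paged"))
# (131072, 8191, 8192)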
@@ -0,0 +1,24 @@
from typing import Optional

SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None


def set_support_chunking(support_chunking: bool):
    global SUPPORT_CHUNKING
    SUPPORT_CHUNKING = support_chunking


def get_support_chunking() -> bool:
    global SUPPORT_CHUNKING
    return SUPPORT_CHUNKING


def set_max_prefill_tokens(max_prefill_tokens: int):
    global MAX_PREFILL_TOKENS
    MAX_PREFILL_TOKENS = max_prefill_tokens


def get_max_prefill_tokens() -> int:
    global MAX_PREFILL_TOKENS
    return MAX_PREFILL_TOKENS
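The final hunk adds a small module of prefill-chunking globals; the server.py hunk above imports it as text_generation_server.utils.prefill_chunking. A short usage example, assuming that module path:

from text_generation_server.utils.prefill_chunking import (
    get_max_prefill_tokens,
    set_max_prefill_tokens,
)

# The gRPC Warmup handler stores the router-provided budget once ...
set_max_prefill_tokens(4096)

# ... and batch-building code can read it back later.
assert get_max_prefill_tokens() == 4096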