diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
index 0e3487dc..3bdfdd83 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
@@ -23,7 +23,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
 )
@@ -172,7 +172,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -279,8 +279,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=None,
             adapter_data=adapter_data,
         )
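Across this patch the per-call `max_s`/`true_max_s` scalars are replaced by a single `hpu_attention_meta: Optional[HPUPagedAttentionMetadata]` argument threaded through to the text model. A minimal stand-in sketch of that calling pattern follows; the dataclass and its fields are hypothetical placeholders, not the real `HPUPagedAttentionMetadata` from `text_generation_server.layers.attention`.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PagedAttentionMeta:
    # Hypothetical stand-in for HPUPagedAttentionMetadata; the field names
    # below are illustrative assumptions, not the real attribute list.
    block_list: Optional[torch.Tensor] = None
    block_mapping: Optional[torch.Tensor] = None


def text_model_forward(
    inputs_embeds: torch.Tensor,
    hpu_attention_meta: Optional[PagedAttentionMeta],
) -> torch.Tensor:
    # One metadata object is passed through instead of max_s / true_max_s
    # scalars; the attention kernels read the paged-KV block layout from it.
    return inputs_embeds


print(text_model_forward(torch.zeros(2, 16), PagedAttentionMeta()).shape)
```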
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
index cf47208b..b26adad7 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
@@ -31,6 +31,7 @@ from text_generation_server.layers import (
 )
 from text_generation_server.layers.attention import (
     Seqlen,
+    HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
@@ -678,23 +679,23 @@ class MllamaTextCrossAttention(nn.Module):
         """Input shape: Batch x Time x Channel"""

         # hidden_states = hidden_states.unsqueeze(0)
         # bsz, q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
-        query_states = query_states.view(-1, self.num_heads, self.head_size)
-        query_states = self.q_norm(query_states)
-
         (
             cross_attention_states,
             cu_seqlen_q,
             cu_seqlen_k,
-            max_q,
-            max_k,
             indices,
         ) = cross_attention_states
+        bs = cu_seqlen_q.size(0) - 1
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
+        query_states = self.q_norm(query_states)

         key_states = self.k_proj(cross_attention_states)
         value_states = self.v_proj(cross_attention_states)
-        key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
-        value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
+        key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
+        value_states = value_states.view(
+            bs, -1, self.num_key_value_heads, self.head_size
+        )

         key_states = self.k_norm(key_states)
@@ -705,9 +706,9 @@ class MllamaTextCrossAttention(nn.Module):
         # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
         # execute sdpa
-        query_states = query_states.unsqueeze(0).transpose(1, 2)
-        key_states = key_states.unsqueeze(0).transpose(1, 2)
-        value_states = value_states.unsqueeze(0).transpose(1, 2)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
         fsdpa_op = ModuleFusedSDPA(FusedSDPA)
         attn_output = fsdpa_op(
             query_states,
@@ -803,9 +804,10 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         block_tables,
         slots,
         seqlen,
-        max_s,
         adapter_data,
         cross_attention_states,  # [ IB, ...]
+        prefill_cache_indices,
+        hpu_attention_meta,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if cross_attention_states is None:
             return hidden_states, residual
@@ -912,7 +914,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
@@ -949,8 +951,6 @@ class FlashMllamaForConditionalGeneration(nn.Module):
                     )
                     * seqlen_k
                 )
-                max_q = cu_seqlen_q[-1].item()
-                max_k = seqlen_k
             else:
                 cu_seqlen_q = torch.arange(
                     seqlen_q + 1, device=device, dtype=torch.int32
@@ -965,16 +965,12 @@ class FlashMllamaForConditionalGeneration(nn.Module):
                     )
                     * seqlen_k
                 )
-                max_q = seqlen_q
-                max_k = seqlen_k

             indices = image_indices[:]
             cross_attention_states = (
                 cross_attention_states,
                 cu_seqlen_q,
                 cu_seqlen_k,
-                max_q,
-                max_k,
                 indices,
             )

@@ -986,7 +982,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             adapter_data=adapter_data,
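A self-contained PyTorch sketch (not repository code) of the reshape logic introduced in `MllamaTextCrossAttention`: `cu_seqlen_q` has one more entry than there are sequences, so `size(0) - 1` recovers the batch size, and the explicit `[bs, seq, heads, head_dim]` views line up with the plain `transpose(1, 2)` calls that now feed FusedSDPA. The shapes below are made up.

```python
import torch

# Cumulative query lengths: two sequences of length 5 -> [0, 5, 10].
cu_seqlen_q = torch.tensor([0, 5, 10], dtype=torch.int32)
bs = cu_seqlen_q.size(0) - 1  # 2 sequences in the batch

num_heads, head_size, seq_len = 4, 8, 5
hidden_states = torch.randn(bs * seq_len, num_heads * head_size)

# Keep an explicit batch dimension instead of a flat token dimension,
# then move heads ahead of the sequence axis for SDPA.
query_states = hidden_states.view(bs, -1, num_heads, head_size)  # [bs, seq, heads, dim]
query_states = query_states.transpose(1, 2)                      # [bs, heads, seq, dim]

print(bs, tuple(query_states.shape))  # 2 (2, 4, 5, 8)
```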
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index b1f89eff..532f118f 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -19,7 +19,7 @@ from torch import nn
 from typing import Optional, List, Tuple

 from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -72,7 +72,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -85,7 +85,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
         if cu_seqlen_prefill is not None:
-            max_s += 1
             position_ids += 1

         if pixel_values is not None:
@@ -110,7 +109,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
         )

         if lm_head_indices is not None:
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
index 923123d6..31a01d7c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
 )
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

 from text_generation_server.layers import (
@@ -742,7 +742,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -829,8 +829,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=None,
             adapter_data=adapter_data,
         )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
index 580398cb..ce5e8115 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -24,7 +24,7 @@ from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
 )
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

 from text_generation_server.layers import (
@@ -485,7 +485,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
@@ -573,8 +573,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=None,
             adapter_data=adapter_data,
         )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index efd9cccd..832efdfa 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -40,6 +40,7 @@ from text_generation_server.layers import (
 )
 from text_generation_server.layers.attention import (
     Seqlen,
+    HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
@@ -906,7 +907,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
@@ -937,8 +938,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=prefill_cache_indices,
         )
         if lm_head_indices is not None:
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index b32ab577..856635fd 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -39,6 +39,7 @@ from text_generation_server.layers import (
 )
 from text_generation_server.layers.attention import (
     Seqlen,
+    HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
@@ -482,7 +483,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         block_tables: torch.Tensor,
         slots: torch.Tensor,
         seqlen: Seqlen,
-        max_s: int,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
@@ -512,8 +513,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             block_tables=block_tables,
             slots=slots,
             seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
+            hpu_attention_meta=hpu_attention_meta,
             prefill_cache_indices=prefill_cache_indices,
         )
         if lm_head_indices is not None:
os.getenv("OMP_NUM_THREADS") is None: - num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) - else: - num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) - if len(numa.memory.get_membind_nodes()) == nodes: - numa.memory.set_membind_nodes((node_id)) - torch.set_num_threads(num_cpus_per_rank) - if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True): - cpu_start = num_cpus_per_rank * rank_offset_per_node - numa.schedule.run_on_cpus( - 0, - *( - numa.info.node_to_cpus(node_id)[ - cpu_start : cpu_start + num_cpus_per_rank - ] - ), - ) - logger.info( - f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}" - ) - - @dataclass class FlashCausalLMBatch(Batch): batch_id: int @@ -1447,16 +1408,13 @@ class FlashCausalLM(Model): def warmup( self, - request: generate_pb2.WarmupRequest, + batch: FlashCausalLMBatch, + max_input_tokens: Optional[int], + max_total_tokens: Optional[int], ): # The warmup batch is the biggest batch we could ever receive self.kv_cache = [] empty_cache() - max_input_tokens = request.max_input_tokens - max_total_tokens = request.max_total_tokens - batch = self.batch_type.from_pb( - request.batch, self.tokenizer, self.dtype, self.device - ) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory @@ -1505,10 +1463,10 @@ class FlashCausalLM(Model): ) log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") - if max_total_tokens is None or max_total_tokens == 0: + if max_total_tokens is None: max_total_tokens = sum(batch.cache_lengths) - if max_input_tokens is None or max_input_tokens == 0: + if max_input_tokens is None: max_input_tokens = max_total_tokens - 1 del _batch, batch diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 5d4d68fd..7cff7797 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -16,7 +16,8 @@ from text_generation_server.models.globals import PREFIX_CACHING from loguru import logger from text_generation_server.utils.log import log_master from transformers import AutoProcessor -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata +import habana_frameworks.torch as htorch tracer = trace.get_tracer(__name__) @@ -447,6 +448,10 @@ class FlashVlmCausalLM(FlashCausalLM): # This makes sure the max_s for the decode pass is correct. 
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 5d4d68fd..7cff7797 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -16,7 +16,8 @@ from text_generation_server.models.globals import PREFIX_CACHING
 from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+import habana_frameworks.torch as htorch


 tracer = trace.get_tracer(__name__)
@@ -447,6 +448,10 @@ class FlashVlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = False
+
         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
@@ -459,13 +464,15 @@ class FlashVlmCausalLM(FlashCausalLM):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            seqlen=seqlen,
+            seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=batch.hpu_attn_meta,
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             pixel_values=batch.pixel_values,
             pixel_attention_mask=batch.pixel_attention_mask,
             image_sizes=batch.image_sizes,
             image_grid_thw=batch.image_grid_thw,
+            **kwargs,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
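The `bypass_hpu_graphs` flag is only forwarded when the runtime is in lazy mode. Below is a minimal sketch of that conditional-kwargs pattern with a dummy forward function; the real check is `htorch.utils.internal.is_lazy()` as in the hunk above.

```python
def dummy_forward(input_ids, **kwargs):
    # Stand-in for self.model.forward; just echoes the extra flags it got.
    return input_ids, kwargs


def run_step(input_ids, lazy_mode: bool):
    kwargs = {}
    if lazy_mode:  # htorch.utils.internal.is_lazy() in the real code path
        kwargs["bypass_hpu_graphs"] = False
    return dummy_forward(input_ids, **kwargs)


print(run_step([1, 2, 3], lazy_mode=True))   # ([1, 2, 3], {'bypass_hpu_graphs': False})
print(run_step([1, 2, 3], lazy_mode=False))  # ([1, 2, 3], {})
```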
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index f149d462..be67b6ae 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -17,8 +17,8 @@ from text_generation_server.models.flash_vlm_causal_lm import (
     FlashVlmCausalLM,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.layers.attention import Seqlen
-
+from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+import habana_frameworks.torch as htorch

 tracer = trace.get_tracer(__name__)

@@ -279,6 +279,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):

         cross_attention_states = batch.cross_attention_states

+        kwargs = {}
+        if htorch.utils.internal.is_lazy():
+            kwargs["bypass_hpu_graphs"] = False
+
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -286,13 +290,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            seqlen=seqlen,
-            max_s=max_s,
+            seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=batch.hpu_attn_meta,
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             cross_attention_states=cross_attention_states,
-            adapter_data=adapter_data,
+            # TODO list
+            adapter_data=None,
             image_indices=batch.image_indices[:],
+            **kwargs,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
index 66e00171..1c7b12b8 100644
--- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
@@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
-max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
-if max_batch_size_str is not None:
-    MAX_BATCH_SIZE = int(max_batch_size_str)
-else:
-    raise ValueError("MAX_BATCH_SIZE is not set")
+

 PREFILL_WARMUP_BATCH_SIZE_LIST = []
 PREFILL_WARMUP_SEQLEN_LIST = []
@@ -1464,6 +1460,11 @@ class VlmCausalLM(Model):
         batch = self.batch_from_pb(request.batch, is_warmup=True)
         max_input_tokens = request.max_input_tokens
         max_prefill_batch_size = batch.input_ids.shape[0]
+        max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
+        if max_batch_size_str is not None:
+            MAX_BATCH_SIZE = int(max_batch_size_str)
+        else:
+            raise ValueError("MAX_BATCH_SIZE is not set")

         try:
             # max prefill batch size warmup
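Reading `MAX_BATCH_SIZE` moves from import time into the warmup path, so merely importing `vlm_causal_lm` no longer raises when the variable is unset. A sketch of the deferred lookup:

```python
import os


def read_max_batch_size() -> int:
    # Deferred read: called from warmup rather than at module import,
    # mirroring the hunk above.
    max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
    if max_batch_size_str is None:
        raise ValueError("MAX_BATCH_SIZE is not set")
    return int(max_batch_size_str)


os.environ["MAX_BATCH_SIZE"] = "8"  # illustrative value
print(read_max_batch_size())  # 8
```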
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
index 7a8a51d6..6e470361 100644
--- a/backends/gaudi/server/text_generation_server/server.py
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -18,10 +18,11 @@ from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id
+from text_generation_server.models.globals import set_model_id, ATTENTION
 from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
 from text_generation_server.utils.tokens import make_tokenizer_optional
+from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -109,14 +110,50 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
-            self.model.warmup(request)
-        )
+        if ATTENTION == "paged":
+            set_max_prefill_tokens(request.max_prefill_tokens)
+            if (
+                self.model.batch_type in VLM_BATCH_TYPES
+            ):  # Hack, i would rather use kwargs in the `from_pb` call
+                batch = self.model.batch_type.from_pb_processor(
+                    request.batch,
+                    self.model.tokenizer,
+                    self.model.processor,
+                    self.model.model.config,
+                    self.model.dtype,
+                    self.model.device,
+                )
+            else:
+                batch = self.model.batch_type.from_pb(
+                    request.batch,
+                    self.model.tokenizer,
+                    self.model.dtype,
+                    self.model.device,
+                )

-        # W/A for the skip tokenizer path
-        # We need to call make_tokenizer_optional after the warmup,
-        # because router is not aware of that feature
-        make_tokenizer_optional(self.model.tokenizer)
+            # Override default values with None for clearer semantics.
+            max_input_tokens = (
+                request.max_input_tokens
+                if request.HasField("max_input_tokens")
+                else None
+            )
+            max_total_tokens = (
+                request.max_total_tokens
+                if request.HasField("max_total_tokens")
+                else None
+            )
+            max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+                self.model.warmup(batch, max_input_tokens, max_total_tokens)
+            )
+        else:
+            max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+                self.model.warmup(request)
+            )
+
+        # W/A for the skip tokenizer path
+        # We need to call make_tokenizer_optional after the warmup,
+        # because router is not aware of that feature
+        make_tokenizer_optional(self.model.tokenizer)

         return generate_pb2.WarmupResponse(
             max_supported_total_tokens=max_supported_total_tokens,
diff --git a/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py
new file mode 100644
index 00000000..c227d30f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py
@@ -0,0 +1,24 @@
+from typing import Optional
+
+SUPPORT_CHUNKING: Optional[bool] = None
+MAX_PREFILL_TOKENS: Optional[int] = None
+
+
+def set_support_chunking(support_chunking: bool):
+    global SUPPORT_CHUNKING
+    SUPPORT_CHUNKING = support_chunking
+
+
+def get_support_chunking() -> bool:
+    global SUPPORT_CHUNKING
+    return SUPPORT_CHUNKING
+
+
+def set_max_prefill_tokens(max_prefill_tokens: int):
+    global MAX_PREFILL_TOKENS
+    MAX_PREFILL_TOKENS = max_prefill_tokens
+
+
+def get_max_prefill_tokens() -> int:
+    global MAX_PREFILL_TOKENS
+    return MAX_PREFILL_TOKENS
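Based only on what this diff shows (`server.py` calls `set_max_prefill_tokens` during `Warmup`), the new module is a pair of process-wide setters/getters. A short usage sketch with an illustrative budget:

```python
from text_generation_server.utils.prefill_chunking import (
    get_max_prefill_tokens,
    set_max_prefill_tokens,
)

# The Warmup RPC stores the router-provided prefill budget once...
set_max_prefill_tokens(4096)  # 4096 is an illustrative value

# ...and later code can read it back without threading it through every call.
assert get_max_prefill_tokens() == 4096
```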