Adjust warmup and enable VLM

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-20 01:09:58 -07:00
parent f95aa42660
commit 36b6612f97
13 changed files with 134 additions and 109 deletions
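
The core refactor in this commit, visible across the model files below, is twofold: every Gaudi/HPU VLM forward() drops the scalar max_s argument in favour of an hpu_attention_meta: Optional[HPUPagedAttentionMetadata] parameter, and warmup is reworked so the gRPC server builds the batch and passes optional token limits instead of the model reading them from the request. The sketch below only illustrates the signature change; the HPUPagedAttentionMetadata fields shown are hypothetical stand-ins, since the real dataclass is not part of this diff.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class HPUPagedAttentionMetadata:
    # Hypothetical fields for illustration only; the real class lives in
    # text_generation_server.layers.attention and is not shown in this diff.
    block_list: Optional[torch.Tensor] = None
    block_mapping: Optional[torch.Tensor] = None
    attn_bias: Optional[torch.Tensor] = None


class ToyVlmForConditionalGeneration(torch.nn.Module):
    def __init__(self, vocab: int = 32, hidden: int = 16):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab)

    # Before: forward(..., seqlen: Seqlen, max_s: int, ...)
    # After:  forward(..., seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ...)
    def forward(
        self,
        input_ids: torch.Tensor,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # A real model would hand hpu_attention_meta to its paged-attention kernels;
        # the toy model just ignores it.
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        return self.lm_head(hidden_states)


print(ToyVlmForConditionalGeneration()(torch.tensor([1, 2, 3]), hpu_attention_meta=None).shape)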

View File

@@ -23,7 +23,7 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
@@ -172,7 +172,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
@@ -279,8 +279,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)

View File

@@ -31,6 +31,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
@@ -678,23 +679,23 @@ class MllamaTextCrossAttention(nn.Module):
"""Input shape: Batch x Time x Channel"""
# hidden_states = hidden_states.unsqueeze(0)
# bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(-1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
(
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
) = cross_attention_states
bs = cu_seqlen_q.size(0) - 1
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(
bs, -1, self.num_key_value_heads, self.head_size
)
key_states = self.k_norm(key_states)
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
@@ -705,9 +706,9 @@ class MllamaTextCrossAttention(nn.Module):
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
# execute sdpa
query_states = query_states.unsqueeze(0).transpose(1, 2)
key_states = key_states.unsqueeze(0).transpose(1, 2)
value_states = value_states.unsqueeze(0).transpose(1, 2)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query_states,
@@ -803,9 +804,10 @@ class FlashLlamaCrossLayer(torch.nn.Module):
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states, # [ IB, ...]
prefill_cache_indices,
hpu_attention_meta,
) -> Tuple[torch.Tensor, torch.Tensor]:
if cross_attention_states is None:
return hidden_states, residual
@@ -912,7 +914,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
@@ -949,8 +951,6 @@ class FlashMllamaForConditionalGeneration(nn.Module):
)
* seqlen_k
)
max_q = cu_seqlen_q[-1].item()
max_k = seqlen_k
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
@@ -965,16 +965,12 @@ class FlashMllamaForConditionalGeneration(nn.Module):
)
* seqlen_k
)
max_q = seqlen_q
max_k = seqlen_k
indices = image_indices[:]
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
)
@@ -986,7 +982,7 @@ class FlashMllamaForConditionalGeneration(nn.Module):
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
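
The MllamaTextCrossAttention hunks above move the query/key/value projections from a token-flattened layout, view(-1, num_heads, head_size) plus unsqueeze(0), to an explicitly batched layout derived from cu_seqlen_q, so the fused SDPA call receives (batch, heads, seq, head_size) tensors directly. Below is a minimal sketch of that layout change, using torch.nn.functional.scaled_dot_product_attention as a portable stand-in for Habana's FusedSDPA (an assumption; the real code runs the fused HPU op), with random tensors standing in for the projection outputs.

import torch
import torch.nn.functional as F

bs, q_len, kv_len = 2, 4, 6
num_heads, num_kv_heads, head_size = 8, 8, 64

# Stand-ins for q_proj(hidden_states) and k_proj/v_proj(cross_attention_states).
hidden_states = torch.randn(bs * q_len, num_heads * head_size)
cross_states = torch.randn(bs * kv_len, num_kv_heads * head_size)

# Old layout: view(-1, num_heads, head_size), then unsqueeze(0) before SDPA.
# New layout: recover the batch size from cu_seqlen_q and keep an explicit batch dim.
cu_seqlen_q = torch.arange(0, (bs + 1) * q_len, q_len, dtype=torch.int32)
batch = cu_seqlen_q.size(0) - 1

q = hidden_states.view(batch, -1, num_heads, head_size).transpose(1, 2)
k = cross_states.view(batch, -1, num_kv_heads, head_size).transpose(1, 2)
v = cross_states.view(batch, -1, num_kv_heads, head_size).transpose(1, 2)

attn_output = F.scaled_dot_product_attention(q, k, v)  # (batch, heads, q_len, head_size)
print(attn_output.shape)  # torch.Size([2, 8, 4, 64])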

View File

@@ -19,7 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
@@ -72,7 +72,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
@@ -85,7 +85,6 @@ class PaliGemmaForConditionalGeneration(nn.Module):
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
@@ -110,7 +109,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:

View File

@@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
@@ -742,7 +742,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
@@ -829,8 +829,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)

View File

@@ -24,7 +24,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
@@ -485,7 +485,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
@@ -573,8 +573,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)

View File

@@ -40,6 +40,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
@@ -906,7 +907,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
@@ -937,8 +938,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:

View File

@@ -39,6 +39,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
@@ -482,7 +483,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
@@ -512,8 +513,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:

View File

@@ -61,7 +61,6 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
empty_cache,
synchronize,
@@ -77,10 +76,6 @@ tracer = trace.get_tracer(__name__)
SLIDING_WINDOW: Optional[int] = None
def small_power_of_2(n: int):
return 1 << ((n - 1).bit_length() - 1)
def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW
SLIDING_WINDOW = sliding_window
@@ -91,40 +86,6 @@ def get_sliding_windows() -> int:
return SLIDING_WINDOW
def init_cpu_threads_env(rank_id: int, world_size: int):
import importlib.util
if importlib.util.find_spec("numa") is not None:
import numa
import psutil
nodes = numa.info.get_max_node() + 1
rank_per_node = math.ceil(world_size / nodes)
num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
node_id = int(rank_id / rank_per_node)
rank_offset_per_node = rank_id % rank_per_node
if os.getenv("OMP_NUM_THREADS") is None:
num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
else:
num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
if len(numa.memory.get_membind_nodes()) == nodes:
numa.memory.set_membind_nodes((node_id))
torch.set_num_threads(num_cpus_per_rank)
if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
cpu_start = num_cpus_per_rank * rank_offset_per_node
numa.schedule.run_on_cpus(
0,
*(
numa.info.node_to_cpus(node_id)[
cpu_start : cpu_start + num_cpus_per_rank
]
),
)
logger.info(
f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
)
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
@@ -1447,16 +1408,13 @@ class FlashCausalLM(Model):
def warmup(
self,
request: generate_pb2.WarmupRequest,
batch: FlashCausalLMBatch,
max_input_tokens: Optional[int],
max_total_tokens: Optional[int],
):
# The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache()
max_input_tokens = request.max_input_tokens
max_total_tokens = request.max_total_tokens
batch = self.batch_type.from_pb(
request.batch, self.tokenizer, self.dtype, self.device
)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
@@ -1505,10 +1463,10 @@
)
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None or max_total_tokens == 0:
if max_total_tokens is None:
max_total_tokens = sum(batch.cache_lengths)
if max_input_tokens is None or max_input_tokens == 0:
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
del _batch, batch
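
The warmup hunks above change FlashCausalLM.warmup to receive the pre-built batch plus optional max_input_tokens / max_total_tokens instead of reading them from the request, which is why the `== 0` sentinel checks disappear: the gRPC layer (see the server diff further down) now maps unset proto fields to None before calling warmup. A small sketch of the resulting fallback logic, assuming a toy list of cache lengths in place of FlashCausalLMBatch.cache_lengths:

from typing import List, Optional, Tuple


def resolve_token_limits(
    cache_lengths: List[int],
    max_input_tokens: Optional[int],
    max_total_tokens: Optional[int],
) -> Tuple[int, int]:
    # Unset limits arrive as None (the server converts unset proto fields via HasField),
    # so the old `== 0` sentinel check is no longer needed.
    if max_total_tokens is None:
        max_total_tokens = sum(cache_lengths)
    if max_input_tokens is None:
        max_input_tokens = max_total_tokens - 1
    return max_input_tokens, max_total_tokens


print(resolve_token_limits([512, 256], None, None))  # (767, 768)
print(resolve_token_limits([512, 256], 100, 2048))   # (100, 2048)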

View File

@@ -16,7 +16,8 @@ from text_generation_server.models.globals import PREFIX_CACHING
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
tracer = trace.get_tracer(__name__)
@@ -447,6 +448,10 @@ class FlashVlmCausalLM(FlashCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
@@ -459,13 +464,15 @@
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None

View File

@@ -17,8 +17,8 @@ from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
tracer = trace.get_tracer(__name__)
@@ -279,6 +279,10 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states = batch.cross_attention_states
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -286,13 +290,15 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
# TODO list
adapter_data=None,
image_indices=batch.image_indices[:],
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None

View File

@@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
if max_batch_size_str is not None:
MAX_BATCH_SIZE = int(max_batch_size_str)
else:
raise ValueError("MAX_BATCH_SIZE is not set")
PREFILL_WARMUP_BATCH_SIZE_LIST = []
PREFILL_WARMUP_SEQLEN_LIST = []
@@ -1464,6 +1460,11 @@ class VlmCausalLM(Model):
batch = self.batch_from_pb(request.batch, is_warmup=True)
max_input_tokens = request.max_input_tokens
max_prefill_batch_size = batch.input_ids.shape[0]
max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
if max_batch_size_str is not None:
MAX_BATCH_SIZE = int(max_batch_size_str)
else:
raise ValueError("MAX_BATCH_SIZE is not set")
try:
# max prefill batch size warmup

View File

@@ -18,10 +18,11 @@ from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
from text_generation_server.models.globals import set_model_id, ATTENTION
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -109,14 +110,50 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
self.model.warmup(request)
)
if ATTENTION == "paged":
set_max_prefill_tokens(request.max_prefill_tokens)
if (
self.model.batch_type in VLM_BATCH_TYPES
): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor(
request.batch,
self.model.tokenizer,
self.model.processor,
self.model.model.config,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
request.batch,
self.model.tokenizer,
self.model.dtype,
self.model.device,
)
# W/A for the skip tokenizer path
# We need to call make_tokenizer_optional after the warmup,
# because router is not aware of that feature
make_tokenizer_optional(self.model.tokenizer)
# Override default values with None for clearer semantics.
max_input_tokens = (
request.max_input_tokens
if request.HasField("max_input_tokens")
else None
)
max_total_tokens = (
request.max_total_tokens
if request.HasField("max_total_tokens")
else None
)
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
self.model.warmup(batch, max_input_tokens, max_total_tokens)
)
else:
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
self.model.warmup(request)
)
# W/A for the skip tokenizer path
# We need to call make_tokenizer_optional after the warmup,
# because router is not aware of that feature
make_tokenizer_optional(self.model.tokenizer)
return generate_pb2.WarmupResponse(
max_supported_total_tokens=max_supported_total_tokens,

View File

@@ -0,0 +1,24 @@
from typing import Optional
SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None
def set_support_chunking(support_chunking: bool):
global SUPPORT_CHUNKING
SUPPORT_CHUNKING = support_chunking
def get_support_chunking() -> bool:
global SUPPORT_CHUNKING
return SUPPORT_CHUNKING
def set_max_prefill_tokens(max_prefill_tokens: int):
global MAX_PREFILL_TOKENS
MAX_PREFILL_TOKENS = max_prefill_tokens
def get_max_prefill_tokens() -> int:
global MAX_PREFILL_TOKENS
return MAX_PREFILL_TOKENS
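
A minimal usage sketch of the new module; the import path matches the set_max_prefill_tokens import added to server.py earlier in this commit, and the Warmup handler is where set_max_prefill_tokens is actually called.

from text_generation_server.utils.prefill_chunking import (
    get_max_prefill_tokens,
    set_max_prefill_tokens,
)

# Warmup records the router-provided prefill budget once...
set_max_prefill_tokens(4096)

# ...and later code reads it back instead of threading it through every call site.
assert get_max_prefill_tokens() == 4096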