From 9281be20c07533dd9aadc0909237ddb25774ba7c Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Wed, 14 May 2025 19:16:00 -0700
Subject: [PATCH] accelerate warmup

Signed-off-by: Wang, Yi A
---
 .../layers/layernorm.py | 17 ++++-----
 .../custom_modeling/flash_cohere_modeling.py | 8 ++++-
 .../custom_modeling/flash_dbrx_modeling.py | 7 +++-
 .../flash_deepseek_v2_modeling.py | 6 ++++
 .../flash_deepseek_v3_modeling.py | 6 ++++
 .../custom_modeling/flash_gemma2_modeling.py | 7 ++++
 .../custom_modeling/flash_gemma_modeling.py | 6 ++++
 .../custom_modeling/flash_gpt2_modeling.py | 7 ++++
 .../custom_modeling/flash_gptj_modeling.py | 6 ++++
 .../custom_modeling/flash_llama_modeling.py | 7 +++-
 .../custom_modeling/flash_mistral_modeling.py | 6 ++++
 .../custom_modeling/flash_mixtral_modeling.py | 6 ++++
 .../custom_modeling/flash_neox_modeling.py | 6 ++++
 .../custom_modeling/flash_phi_modeling.py | 6 ++++
 .../custom_modeling/flash_phi_moe_modeling.py | 1 -
 .../custom_modeling/flash_qwen2_modeling.py | 6 ++++
 .../custom_modeling/flash_rw_modeling.py | 6 ++++
 .../flash_santacoder_modeling.py | 6 ++++
 .../flash_starcoder2_modeling.py | 6 ++++
 .../models/flash_causal_lm.py | 36 +++++++++++++++----
 .../models/flash_vlm_causal_lm.py | 19 ++++++++--
 .../models/mllama_causal_lm.py | 31 ++++++++++++----
 22 files changed, 181 insertions(+), 31 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py
index 84878791..4bbb6c1f 100644
--- a/backends/gaudi/server/text_generation_server/layers/layernorm.py
+++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py
@@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)
 
     def forward(self, hidden_states, residual=None):
-        from vllm_hpu_extension.kernels import rms_norm
-
-        orig_shape = hidden_states.shape
         if residual is not None:
-            residual += hidden_states.view(residual.shape)
-        else:
-            residual = hidden_states
-        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
-        if len(orig_shape) == 2:
-            residual = residual.unsqueeze(0)
-        x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
-        return x.view(orig_shape), residual.view(orig_shape)
+            hidden_states += residual
+        residual = hidden_states
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(self.weight.dtype), residual
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 3bcc689d..801ae09e 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import (
     apply_rotary_pos_emb,
 )
 
+import habana_frameworks.torch as htorch
+
 
 class CohereRotary(PositionRotaryEmbedding):
     def forward(
@@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
-
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 15c243c9..76972d38 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
 from vllm_hpu_extension.ops import DynamicFusedMOE
+import habana_frameworks.torch as htorch
 
 
 class DbrxAttentionConfig(PretrainedConfig):
@@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
-
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index 9d61c694..6ac7fc1a 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch
 
 
 class DeepseekV2Config(PretrainedConfig):
@@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
index f6620d51..e0481691 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
@@ -41,6 +41,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch
 
 
 def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
@@ -651,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -663,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index 79f21b0f..a5860823 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class Gemma2Config(PretrainedConfig):
@@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
+
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 609f03ac..3d678df1 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class GemmaConfig(PretrainedConfig):
@@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index 10024a6d..ed413662 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     get_linear,
 )
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
+import habana_frameworks.torch as htorch
 
 
 def load_qkv(config, prefix: str, weights, head_size, num_heads):
@@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module):
         hidden_states = inputs_embeds
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
+
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states = self.norm(hidden_states)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
index 41eeab78..cde03a00 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_attention(config, prefix: str, weights):
@@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 81af5560..0edea03a 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -26,7 +26,7 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-
+import habana_frameworks.torch as htorch
 from text_generation_server.layers.attention import (
     KVCache,
     get_kv_scales,
@@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module):
                 cross_attention_states,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index d23d4f67..75d9d360 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+import habana_frameworks.torch as htorch
 
 
 class MistralConfig(PretrainedConfig):
@@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 1ef6be48..f47986d8 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class MixtralConfig(PretrainedConfig):
@@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 33f63333..29620826 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class GPTNeoXConfig(TransformersGPTNeoXConfig):
@@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 0c777912..12830991 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+import habana_frameworks.torch as htorch
 
 
 class PhiConfig(PretrainedConfig):
@@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
index bb585cc4..c28f3aee 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
@@ -18,7 +18,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 
-
 logger = logging.get_logger(__name__)
 
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index af4b404d..7c7ac03e 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_attention(config, prefix, weights):
@@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
         )
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
                 hidden_states,
@@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 141e13a6..76a2cd01 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     HPUPagedAttentionMetadata,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_row(config, prefix: str, weights, bias: bool):
@@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
         cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.h):
             hidden_states, residual = layer(
                 hidden_states,
@@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index b68f4784..c64b2ff7 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_multi_mqa(
@@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 76f6f473..94c60eb6 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class Starcoder2Config(PretrainedConfig):
@@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 8bbd46b5..f5031d6f 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1596,13 +1596,17 @@ class FlashCausalLM(Model):
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
         max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        model_max_length = self.tokenizer.model_max_length
+        max_position_embeddings = getattr(
+            self.config, "max_position_embeddings", model_max_length
+        )
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
             max_num_seqs,  # self.max_num_prefill_seqs, #TODO
             BLOCK_SIZE,
             max_num_seqs * max_total_tokens_aligned,
             False,
-            self.tokenizer.model_max_length,
+            min(model_max_length, max_position_embeddings),
             max_input_tokens,
             max_total_tokens_aligned,
         )
@@ -1631,30 +1635,48 @@ class FlashCausalLM(Model):
         return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
 
     def warmup_hpu_graph(self, batch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+
+        for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            warmup_shape_count += 1
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
             synchronize(self.device)
 
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
             synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )
 
     def warmup_prefill(
         self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index a1a7ca4d..13393051 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -23,6 +23,7 @@ from text_generation_server.layers.attention import (
     _async_h2d_tensor_copy,
 )
 import habana_frameworks.torch as htorch
+import time
 from text_generation_server.utils.import_utils import (
     synchronize,
 )
@@ -440,20 +441,32 @@ class FlashVlmCausalLM(FlashCausalLM):
         )
 
     def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
 
+        # only warm up decode; for prefill, the image pixel size may change, making the warmup useless
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
             synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )
 
     def forward(
         self,
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index dac65fea..0e5544f2 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -32,6 +32,7 @@ from text_generation_server.utils.import_utils import (
 )
 import torch.nn.functional as F
 from text_generation_server.utils.log import log_master
+import time
 
 tracer = trace.get_tracer(__name__)
 
@@ -345,29 +346,47 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
 
     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+        for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            warmup_shape_count += 1
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
             synchronize(self.device)
+
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
             synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )
 
     def forward(
         self,
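
Note on the reordering above (an illustration, not part of the patch): the warmup loops no longer walk the buckets with reversed(); they sort them with the two key functions, so prompt shapes are warmed from the smallest batch_size * seq_len upward and decode shapes from the largest batch_size downward. The standalone Python sketch below copies those two key functions and applies them to made-up bucket shapes to show the resulting visit order:

    # Standalone sketch of the warmup bucket ordering. The key functions are
    # copied from warmup_hpu_graph(); the bucket shapes are hypothetical.

    def ordering_function_min_tokens(b):
        # prompt buckets are (batch_size, seq_len): smallest total token count first
        return (b[0] * b[1], b[1], b[0])

    def ordering_function_max_bs(b):
        # decode buckets are (batch_size, block_num): largest batch size first
        return (-b[0], b[1])

    prompt_buckets = [(8, 512), (1, 1024), (4, 256), (2, 2048)]
    decode_buckets = [(8, 64), (32, 256), (16, 128), (32, 128)]

    print(sorted(prompt_buckets, key=ordering_function_min_tokens))
    # -> [(4, 256), (1, 1024), (8, 512), (2, 2048)]
    print(sorted(decode_buckets, key=ordering_function_max_bs))
    # -> [(32, 128), (32, 256), (16, 128), (8, 64)]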