accelerate warmup

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2025-05-14 19:16:00 -07:00
parent b2bd163d19
commit 9281be20c0
22 changed files with 181 additions and 31 deletions

View File

@@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)

     def forward(self, hidden_states, residual=None):
-        from vllm_hpu_extension.kernels import rms_norm
-
-        orig_shape = hidden_states.shape
         if residual is not None:
-            residual += hidden_states.view(residual.shape)
-        else:
-            residual = hidden_states
-        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
-        if len(orig_shape) == 2:
-            residual = residual.unsqueeze(0)
-        x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
-        return x.view(orig_shape), residual.view(orig_shape)
+            hidden_states += residual
+        residual = hidden_states
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(self.weight.dtype), residual
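For reference outside the diff: the new forward path is the stock RMSNorm computation done in float32, returning the pre-norm sum as the residual for the next layer. A minimal self-contained sketch of that math (the helper name, the 1e-6 epsilon, and the test tensors are illustrative, not taken from this commit):

import torch

def rms_norm_reference(hidden_states, weight, eps=1e-6, residual=None):
    # Same math as the new FastRMSNorm.forward above, written out of context.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states  # returned so the caller can feed it to the next layer
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(weight.dtype), residual

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
out, res = rms_norm_reference(x, w)
print(out.shape, res.shape)  # torch.Size([2, 8]) torch.Size([2, 8])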

View File

@@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import (
     apply_rotary_pos_emb,
 )

+import habana_frameworks.torch as htorch
+

 class CohereRotary(PositionRotaryEmbedding):
     def forward(
@@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
-
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)
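The same three-part change repeats in each of the model files below: import htorch, detect lazy mode once per forward, and call htorch.core.mark_step() before the decoder loop and again after every layer. In HPU lazy mode, mark_step() flushes the ops recorded so far into a graph, so the per-layer boundaries presumably let warmup compile and reuse one small graph per layer instead of a single graph spanning the whole model. A schematic of the pattern (layers, hidden_states, and residual stand in for each model's real arguments; this is not any one model's actual forward):

import habana_frameworks.torch as htorch

def run_decoder_layers(layers, hidden_states, residual=None):
    # Schematic of the repeated change across the modeling files in this commit.
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()  # flush pending lazy ops before the layer loop
    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual)
        if lazy_mode:
            htorch.core.mark_step()  # graph boundary after each decoder layer
    return hidden_states, residual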

View File

@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
 from vllm_hpu_extension.ops import DynamicFusedMOE
+import habana_frameworks.torch as htorch


 class DbrxAttentionConfig(PretrainedConfig):
@@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
         residual = None
-
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch


 class DeepseekV2Config(PretrainedConfig):
@@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -41,6 +41,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch


 def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
@@ -651,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -663,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


 class Gemma2Config(PretrainedConfig):
@@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


 class GemmaConfig(PretrainedConfig):
@@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     get_linear,
 )
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
+import habana_frameworks.torch as htorch


 def load_qkv(config, prefix: str, weights, head_size, num_heads):
@@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module):
         hidden_states = inputs_embeds

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states = self.norm(hidden_states)

View File

@@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch


 def load_attention(config, prefix: str, weights):
@@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@@ -26,7 +26,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN

-
+import habana_frameworks.torch as htorch
 from text_generation_server.layers.attention import (
     KVCache,
     get_kv_scales,
@@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module):
                 cross_attention_states,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+import habana_frameworks.torch as htorch


 class MistralConfig(PretrainedConfig):
@@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

         return hidden_states

View File

@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


 class MixtralConfig(PretrainedConfig):
@@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


 class GPTNeoXConfig(TransformersGPTNeoXConfig):
@@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.final_layer_norm(hidden_states, residual)

View File

@@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+import habana_frameworks.torch as htorch


 class PhiConfig(PretrainedConfig):
@@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -18,7 +18,6 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging

 logger = logging.get_logger(__name__)

View File

@@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+import habana_frameworks.torch as htorch


 def load_attention(config, prefix, weights):
@@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
         )

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
                 hidden_states,
@@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states)

View File

@@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     HPUPagedAttentionMetadata,
 )
+import habana_frameworks.torch as htorch


 def load_row(config, prefix: str, weights, bias: bool):
@@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
         cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.h):
             hidden_states, residual = layer(
                 hidden_states,
@@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+import habana_frameworks.torch as htorch


 def load_multi_mqa(
@@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch


 class Starcoder2Config(PretrainedConfig):
@@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()

         hidden_states, _ = self.norm(hidden_states, residual)

View File

@@ -1596,13 +1596,17 @@ class FlashCausalLM(Model):
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
         max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        model_max_length = self.tokenizer.model_max_length
+        max_position_embeddings = getattr(
+            self.config, "max_position_embeddings", model_max_length
+        )
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
             max_num_seqs,  # self.max_num_prefill_seqs, #TODO
             BLOCK_SIZE,
             max_num_seqs * max_total_tokens_aligned,
             False,
-            self.tokenizer.model_max_length,
+            min(model_max_length, max_position_embeddings),
             max_input_tokens,
             max_total_tokens_aligned,
         )
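tokenizer.model_max_length is often a very large sentinel when the tokenizer config declares no limit, so the bucketing context's max sequence length is now capped by the model's max_position_embeddings as well. A standalone illustration with stand-in config/tokenizer objects (the values are made up):

class DummyConfig:
    max_position_embeddings = 4096  # stand-in for self.config

class DummyTokenizer:
    model_max_length = int(1e30)  # stand-in for a tokenizer with no declared limit

config, tokenizer = DummyConfig(), DummyTokenizer()
model_max_length = tokenizer.model_max_length
max_position_embeddings = getattr(config, "max_position_embeddings", model_max_length)
print(min(model_max_length, max_position_embeddings))  # 4096, not 1e30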
@@ -1631,30 +1635,48 @@ class FlashCausalLM(Model):
         return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture

     def warmup_hpu_graph(self, batch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+        for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            warmup_shape_count += 1
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
         synchronize(self.device)
+
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )

     def warmup_prefill(
         self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
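Instead of iterating the bucket lists in reversed() order, warmup now sorts them explicitly: prefill shapes run from the fewest total tokens to the most, and decode shapes from the largest batch size down. A self-contained illustration of the two sort keys, using made-up bucket tuples:

def ordering_function_min_tokens(b):
    # prefill buckets (batch_size, seq_len): cheapest total token count first
    return (b[0] * b[1], b[1], b[0])

def ordering_function_max_bs(b):
    # decode buckets (batch_size, block_num): largest batch size first, then fewest blocks
    return (-b[0], b[1])

prompt_buckets = [(4, 1024), (1, 512), (2, 256), (8, 128)]
decode_buckets = [(1, 64), (8, 256), (4, 128), (8, 64)]

print(sorted(prompt_buckets, key=ordering_function_min_tokens))
# [(2, 256), (1, 512), (8, 128), (4, 1024)]
print(sorted(decode_buckets, key=ordering_function_max_bs))
# [(8, 64), (8, 256), (4, 128), (1, 64)]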

View File

@@ -23,6 +23,7 @@ from text_generation_server.layers.attention import (
     _async_h2d_tensor_copy,
 )
 import habana_frameworks.torch as htorch
+import time
 from text_generation_server.utils.import_utils import (
     synchronize,
 )
@@ -440,20 +441,32 @@ class FlashVlmCausalLM(FlashCausalLM):
         )

     def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
         # only warmup decode, for prefill, image pixal size may change, make the warmup useless
+
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
        )

     def forward(
         self,

View File

@@ -32,6 +32,7 @@ from text_generation_server.utils.import_utils import (
 )
 import torch.nn.functional as F
 from text_generation_server.utils.log import log_master
+import time

 tracer = trace.get_tracer(__name__)
@@ -345,29 +346,47 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )

     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+        for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            warmup_shape_count += 1
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
         synchronize(self.device)
+
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )

     def forward(
         self,