Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

accelerate warmup

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Parent: b2bd163d19
Commit: 9281be20c0
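Summary, as read from the hunks below: the HPU warmup path is reworked to warm prompt buckets in increasing total-token order and decode buckets in decreasing batch-size order, to count and log the warmed shapes and total warmup time, and to clamp the bucketing context's max sequence length with the model config's max_position_embeddings; the model forward passes gain per-layer htorch.core.mark_step() calls guarded by a lazy-mode check, and FastRMSNorm.forward switches to a plain PyTorch RMSNorm implementation.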
@@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)
 
     def forward(self, hidden_states, residual=None):
-        from vllm_hpu_extension.kernels import rms_norm
-
-        orig_shape = hidden_states.shape
         if residual is not None:
-            residual += hidden_states.view(residual.shape)
-        else:
-            residual = hidden_states
-        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
-        if len(orig_shape) == 2:
-            residual = residual.unsqueeze(0)
-        x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
-        return x.view(orig_shape), residual.view(orig_shape)
+            hidden_states += residual
+        residual = hidden_states
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(self.weight.dtype), residual
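For reference, a minimal standalone sketch of the plain-PyTorch RMSNorm computation that the right-hand side of this hunk adopts; the free-function wrapper and the example shapes are illustrative, not the repository's exact code:

import torch

def rms_norm_forward(weight, variance_epsilon, hidden_states, residual=None):
    # Fold the residual into the activations before normalizing,
    # and return it so the caller can reuse it for the next layer.
    if residual is not None:
        hidden_states += residual
    residual = hidden_states
    # Compute the norm in float32 for numerical stability.
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return weight * hidden_states.to(weight.dtype), residual

# Example usage with made-up shapes.
weight = torch.ones(16)
x = torch.randn(4, 16)
out, res = rms_norm_forward(weight, 1e-6, x)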
@@ -51,6 +51,8 @@ from habana_frameworks.torch.hpex.kernels import (
     apply_rotary_pos_emb,
 )
 
+import habana_frameworks.torch as htorch
+
 
 class CohereRotary(PositionRotaryEmbedding):
     def forward(
@@ -420,7 +422,9 @@ class FlashCohereModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -433,6 +437,8 @@ class FlashCohereModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
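The hunks that follow apply this same three-line pattern to the remaining modeling files (DBRX, DeepSeek-V2/V3, Gemma, Gemma2, GPT-2, GPT-J, Llama, Mistral, Mixtral, GPT-NeoX, Phi, Qwen2, Falcon/RW, Santacoder, Starcoder2). A minimal sketch of the idea, assuming a Gaudi environment where habana_frameworks is installed; the layer loop and signatures here are illustrative rather than the repository's exact forward:

import habana_frameworks.torch as htorch

def run_layers(layers, hidden_states, residual=None):
    # In lazy mode, mark_step() flushes the accumulated graph; calling it
    # once before the loop and once per layer keeps each captured graph
    # roughly one decoder layer in size instead of the whole model.
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()
    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual)
        if lazy_mode:
            htorch.core.mark_step()
    return hidden_states, residual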
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
 from vllm_hpu_extension.ops import DynamicFusedMOE
+import habana_frameworks.torch as htorch
 
 
 class DbrxAttentionConfig(PretrainedConfig):
@@ -682,8 +683,10 @@ class DbrxModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -696,6 +699,8 @@ class DbrxModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -40,6 +40,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch
 
 
 class DeepseekV2Config(PretrainedConfig):
@@ -575,6 +576,9 @@ class DeepseekV2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -587,6 +591,8 @@ class DeepseekV2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -41,6 +41,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.weights import Weights
+import habana_frameworks.torch as htorch
 
 
 def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
@@ -651,6 +652,9 @@ class DeepseekV3Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -663,6 +667,8 @@ class DeepseekV3Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -46,6 +46,7 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class Gemma2Config(PretrainedConfig):
@@ -472,6 +473,10 @@ class FlashGemma2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
+
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -485,6 +490,8 @@ class FlashGemma2Model(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class GemmaConfig(PretrainedConfig):
@@ -394,6 +395,9 @@ class FlashGemmaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -406,6 +410,8 @@ class FlashGemmaModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     get_linear,
 )
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
+import habana_frameworks.torch as htorch
 
 
 def load_qkv(config, prefix: str, weights, head_size, num_heads):
@@ -385,6 +386,10 @@ class FlashGPT2Model(torch.nn.Module):
         hidden_states = inputs_embeds
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
+
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -395,6 +400,8 @@ class FlashGPT2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states = self.norm(hidden_states)
 
@@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_attention(config, prefix: str, weights):
@@ -330,6 +331,9 @@ class FlashGPTJModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -342,6 +346,8 @@ class FlashGPTJModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
@@ -26,7 +26,7 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-
+import habana_frameworks.torch as htorch
 from text_generation_server.layers.attention import (
     KVCache,
     get_kv_scales,
@@ -554,6 +554,9 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -568,6 +571,8 @@ class FlashLlamaModel(torch.nn.Module):
                 cross_attention_states,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+import habana_frameworks.torch as htorch
 
 
 class MistralConfig(PretrainedConfig):
@@ -401,6 +402,9 @@ class MistralModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -414,6 +418,8 @@ class MistralModel(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -44,6 +44,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class MixtralConfig(PretrainedConfig):
@@ -452,6 +453,9 @@ class MixtralModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -464,6 +468,8 @@ class MixtralModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class GPTNeoXConfig(TransformersGPTNeoXConfig):
@@ -360,6 +361,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -372,6 +376,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
@@ -26,6 +26,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+import habana_frameworks.torch as htorch
 
 
 class PhiConfig(PretrainedConfig):
@@ -353,6 +354,9 @@ class FlashPhiModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -365,6 +369,8 @@ class FlashPhiModel(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -18,7 +18,6 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -22,6 +22,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_attention(config, prefix, weights):
@@ -294,6 +295,9 @@ class Qwen2Model(torch.nn.Module):
         )
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
                 hidden_states,
@@ -306,6 +310,8 @@ class Qwen2Model(torch.nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states)
 
@@ -21,6 +21,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     HPUPagedAttentionMetadata,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_row(config, prefix: str, weights, bias: bool):
@@ -634,6 +635,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
         cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.h):
             hidden_states, residual = layer(
                 hidden_states,
@@ -646,6 +650,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
@@ -23,6 +23,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+import habana_frameworks.torch as htorch
 
 
 def load_multi_mqa(
@@ -442,6 +443,9 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -452,6 +456,8 @@ class FlashSantacoderModel(nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
@@ -50,6 +50,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+import habana_frameworks.torch as htorch
 
 
 class Starcoder2Config(PretrainedConfig):
@@ -517,6 +518,9 @@ class Starcoder2Model(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -530,6 +534,8 @@ class Starcoder2Model(torch.nn.Module):
                 adapter_data,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
@@ -1596,13 +1596,17 @@ class FlashCausalLM(Model):
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
         max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        model_max_length = self.tokenizer.model_max_length
+        max_position_embeddings = getattr(
+            self.config, "max_position_embeddings", model_max_length
+        )
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
             max_num_seqs,  # self.max_num_prefill_seqs, #TODO
             BLOCK_SIZE,
             max_num_seqs * max_total_tokens_aligned,
             False,
-            self.tokenizer.model_max_length,
+            min(model_max_length, max_position_embeddings),
             max_input_tokens,
             max_total_tokens_aligned,
         )
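A small sketch of the new upper bound fed to HPUBucketingContext: the tokenizer's model_max_length clamped by the model config's max_position_embeddings. Names mirror the hunk above; the free function, BLOCK_SIZE value, and surrounding objects are illustrative only:

import math

BLOCK_SIZE = 128  # illustrative block size, not necessarily the backend's value

def bucketing_limits(tokenizer, config, max_total_tokens):
    # Align the total token budget to the KV-cache block size.
    max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
    model_max_length = tokenizer.model_max_length
    # Fall back to model_max_length when the config has no max_position_embeddings.
    max_position_embeddings = getattr(
        config, "max_position_embeddings", model_max_length
    )
    max_sequence_length = min(model_max_length, max_position_embeddings)
    return max_total_tokens_aligned, max_sequence_length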
@@ -1631,30 +1635,48 @@ class FlashCausalLM(Model):
         return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
 
     def warmup_hpu_graph(self, batch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+
+        for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            warmup_shape_count += 1
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
             synchronize(self.device)
 
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
             synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )
 
     def warmup_prefill(
         self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
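A runnable sketch of the new warmup ordering on made-up buckets: prompt buckets are warmed in increasing total-token order, decode buckets in decreasing batch-size order (the bucket values below are illustrative; only the two key functions come from the hunk above):

def ordering_function_min_tokens(b):
    # (batch_size * seq_len, seq_len, batch_size): smallest prefill shapes first.
    return (b[0] * b[1], b[1], b[0])

def ordering_function_max_bs(b):
    # (-batch_size, block_num): largest decode batch first.
    return (-b[0], b[1])

prompt_buckets = [(4, 1024), (1, 512), (2, 256), (8, 128)]
decode_buckets = [(1, 64), (8, 256), (4, 128)]

print(sorted(prompt_buckets, key=ordering_function_min_tokens))
# [(2, 256), (1, 512), (8, 128), (4, 1024)]
print(sorted(decode_buckets, key=ordering_function_max_bs))
# [(8, 256), (4, 128), (1, 64)]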
@@ -23,6 +23,7 @@ from text_generation_server.layers.attention import (
     _async_h2d_tensor_copy,
 )
 import habana_frameworks.torch as htorch
+import time
 from text_generation_server.utils.import_utils import (
     synchronize,
 )
@@ -440,20 +441,32 @@ class FlashVlmCausalLM(FlashCausalLM):
         )
 
     def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
 
         # only warmup decode, for prefill, image pixal size may change, make the warmup useless
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
             synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )
 
     def forward(
         self,
@@ -32,6 +32,7 @@ from text_generation_server.utils.import_utils import (
 )
 import torch.nn.functional as F
 from text_generation_server.utils.log import log_master
+import time
 
 tracer = trace.get_tracer(__name__)
 
@@ -345,29 +346,47 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
 
     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
+        start_time = time.time()
+        warmup_shape_count = 0
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
-        for i, (batch_size, seq_len) in enumerate(
-            reversed(self.bucketing_ctx.prompt_buckets)
-        ):
+
+        def ordering_function_min_tokens(b):
+            return (b[0] * b[1], b[1], b[0])
+
+        buckets = list(
+            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
+        )
+        for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            warmup_shape_count += 1
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
             synchronize(self.device)
 
+        def ordering_function_max_bs(b):
+            return (-b[0], b[1])
+
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
-        for i, (batch_size, block_num) in enumerate(
-            reversed(self.bucketing_ctx.decode_buckets)
-        ):
+        buckets = list(
+            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
+        )
+        for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            warmup_shape_count += 1
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
             synchronize(self.device)
+        log_master(
+            logger.info,
+            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
+        )
 
     def forward(
         self,