Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 03:14:53 +00:00
Add mark_step into llama4
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent ad41abd68c
commit 2e8d3e91ea
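In brief (based on the diff below): the change imports habana_frameworks.torch as htorch and, when the HPU lazy execution mode is active (htorch.utils.internal.is_lazy()), calls htorch.core.mark_step() once before the decoder-layer loop in Llama4TextModel and once after each layer. In lazy mode, mark_step() flushes the operations queued so far so they are compiled and executed, splitting the forward pass into per-layer graphs instead of one large graph. A minimal sketch of the same pattern follows the diff.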
@@ -22,6 +22,7 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
 
+import habana_frameworks.torch as htorch
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module):
         )
 
         freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
 
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
@@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module):
                 position_ids=position_ids,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states)
 
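For illustration only (not part of the commit), here is a minimal sketch of the per-layer mark_step pattern under HPU lazy mode. ToyBlock and run_decoder are hypothetical names used for this sketch; htorch.core.mark_step() and htorch.utils.internal.is_lazy() are the same calls introduced in the diff above, and the snippet assumes a Gaudi machine where habana_frameworks.torch is installed and the "hpu" device is available.

import torch
from torch import nn
import torch.nn.functional as F

import habana_frameworks.torch as htorch  # registers the "hpu" device; Gaudi-only


class ToyBlock(nn.Module):
    # Stand-in for a Llama4 decoder layer (hypothetical, for illustration).
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.silu(self.proj(x))


def run_decoder(layers: nn.ModuleList, hidden_states: torch.Tensor) -> torch.Tensor:
    # In lazy mode, HPU ops are only queued; mark_step() flushes the queued ops
    # so each decoder layer runs as its own graph rather than one large graph.
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()

    for layer in layers:
        hidden_states = layer(hidden_states)
        if lazy_mode:
            htorch.core.mark_step()

    return hidden_states


if __name__ == "__main__":
    hidden = 64
    layers = nn.ModuleList([ToyBlock(hidden) for _ in range(4)]).to("hpu")
    x = torch.randn(1, 8, hidden, device="hpu")
    print(run_decoder(layers, x).shape)

The usual motivation for per-layer graph breaks is to keep each lazy-mode graph small and identical in shape across steps, which tends to help graph caching and memory behavior on Gaudi.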