Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 03:14:53 +00:00)
Add mark_step into llama4

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent ad41abd68c
commit 2e8d3e91ea
@@ -22,6 +22,7 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
 
+import habana_frameworks.torch as htorch
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module):
         )
 
         freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
 
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
@@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module):
                 position_ids=position_ids,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states)
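For context: on Gaudi HPUs, htorch.core.mark_step() tells the lazy-mode bridge to stop accumulating operations and execute the graph built so far, so placing a mark step before the decoder loop and after each layer bounds the size of each compiled segment. Below is a minimal sketch of the pattern this diff applies; run_decoder_layers and the generic layers list are illustrative stand-ins, not the actual Llama4TextModel.forward from this repository.

# Minimal sketch (illustrative, not the actual TGI module): how
# htorch.core.mark_step() is typically placed around a decoder-layer loop
# when Habana lazy mode is active.
import torch
import habana_frameworks.torch as htorch


def run_decoder_layers(hidden_states: torch.Tensor, layers) -> torch.Tensor:
    # is_lazy() reports whether Habana lazy execution mode is enabled;
    # in eager mode the mark_step() calls are simply skipped.
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        # Flush the ops accumulated so far before entering the layer loop.
        htorch.core.mark_step()

    for layer in layers:
        hidden_states = layer(hidden_states)
        if lazy_mode:
            # Cut a graph boundary after every layer so each layer is
            # compiled and executed as its own segment.
            htorch.core.mark_step()

    return hidden_states

In eager mode the extra branches are never taken, so the non-lazy execution path is left unchanged by this kind of change.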