Fix the RotaryEmbedding

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-11 23:39:42 +00:00
parent f0acbbf10c
commit 9414bcca0f


@@ -291,13 +291,11 @@ class Llama4TextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids_expanded.to(device_type)
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-            cos = torch.cos(freqs) * self.attention_scaling
-            sin = torch.sin(freqs) * self.attention_scaling
-            cos = cos.reshape(-1, 1, cos.shape[-1])
-            sin = sin.reshape(-1, 1, sin.shape[-1])
-            freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling
-            freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
-        return freqs_cis
+            freqs_cis = (
+                torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+                * self.attention_scaling
+            )
+        return freqs_cis.to(dtype=x.dtype, device=x.device)
 
 
 class Llama4TextAttention(FlashLlamaAttention):
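
For context on how the corrected output is consumed: after this change, forward() returns cos(freqs) and sin(freqs) stacked on a trailing dimension of size 2, scaled once by attention_scaling and cast to the activation's dtype and device. The snippet below is a minimal sketch, not this repository's code; the helper name apply_rotary_sketch, the tensor shapes, and the interleaved even/odd pair layout are assumptions made for illustration. It shows one way such a [cos, sin] tensor can rotate query/key states, which is numerically equivalent to multiplying by the complex polar form of the same angles.

import torch


def apply_rotary_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq_len, num_heads, head_dim], rotary pairs interleaved along head_dim.
    # freqs_cis: [batch, seq_len, head_dim // 2, 2], cos in [..., 0] and sin in [..., 1],
    # already multiplied by attention_scaling as in the forward() above (assumed shapes).
    cos = freqs_cis[..., 0].unsqueeze(2)  # [batch, seq_len, 1, head_dim // 2]
    sin = freqs_cis[..., 1].unsqueeze(2)
    x1, x2 = x.float()[..., 0::2], x.float()[..., 1::2]  # even/odd channels form each pair
    # Per-frequency 2D rotation; equivalent to multiplying by exp(i * freqs) in complex form.
    rotated = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.flatten(-2).to(dtype=x.dtype)


if __name__ == "__main__":
    batch, seq_len, num_heads, head_dim = 1, 8, 4, 64
    q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
    # Stand-in for the module output: cos/sin stacked on a trailing dim of size 2.
    freqs = torch.randn(batch, seq_len, head_dim // 2)
    freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
    print(apply_rotary_sketch(q, freqs_cis).shape)  # torch.Size([1, 8, 4, 64])

With a consumer like this, applying attention_scaling once at stack time and returning freqs_cis already in x.dtype on x.device avoids double scaling and an implicit upcast or device copy in every attention layer, which appears to be the motivation for the changed return statement.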