Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00).

Commit 9414bcca0f — "Fix the RotaryEmbedding"
Signed-off-by: yuanwu <yuan.wu@intel.com>
Parent commit: f0acbbf10c
```diff
@@ -291,13 +291,11 @@ class Llama4TextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids_expanded.to(device_type)
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-            cos = torch.cos(freqs) * self.attention_scaling
-            sin = torch.sin(freqs) * self.attention_scaling
-            cos = cos.reshape(-1, 1, cos.shape[-1])
-            sin = sin.reshape(-1, 1, sin.shape[-1])
-            freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling
-            freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
-        return freqs_cis
+            freqs_cis = (
+                torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+                * self.attention_scaling
+            )
+        return freqs_cis.to(dtype=x.dtype, device=x.device)
 
 
 class Llama4TextAttention(FlashLlamaAttention):
```
Loading…
Reference in New Issue
Block a user