mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Fix the RotaryEmbedding
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
f0acbbf10c
commit
9414bcca0f
@ -291,13 +291,11 @@ class Llama4TextRotaryEmbedding(nn.Module):
|
|||||||
position_ids_expanded = position_ids_expanded.to(device_type)
|
position_ids_expanded = position_ids_expanded.to(device_type)
|
||||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||||
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
|
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
|
||||||
cos = torch.cos(freqs) * self.attention_scaling
|
freqs_cis = (
|
||||||
sin = torch.sin(freqs) * self.attention_scaling
|
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
|
||||||
cos = cos.reshape(-1, 1, cos.shape[-1])
|
* self.attention_scaling
|
||||||
sin = sin.reshape(-1, 1, sin.shape[-1])
|
)
|
||||||
freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling
|
return freqs_cis.to(dtype=x.dtype, device=x.device)
|
||||||
freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
|
|
||||||
return freqs_cis
|
|
||||||
|
|
||||||
|
|
||||||
class Llama4TextAttention(FlashLlamaAttention):
|
class Llama4TextAttention(FlashLlamaAttention):
|
||||||
|
Loading…
Reference in New Issue
Block a user