diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 4e83ba3c..98994e48 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -291,13 +291,11 @@ class Llama4TextRotaryEmbedding(nn.Module):
             position_ids_expanded = position_ids_expanded.to(device_type)
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-            cos = torch.cos(freqs) * self.attention_scaling
-            sin = torch.sin(freqs) * self.attention_scaling
-            cos = cos.reshape(-1, 1, cos.shape[-1])
-            sin = sin.reshape(-1, 1, sin.shape[-1])
-            freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling
-            freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
-        return freqs_cis
+            freqs_cis = (
+                torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+                * self.attention_scaling
+            )
+        return freqs_cis.to(dtype=x.dtype, device=x.device)
 
 
 class Llama4TextAttention(FlashLlamaAttention):
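
The hunk replaces a dead-code path (the first `freqs_cis` assignment, which applied `attention_scaling` twice, was immediately overwritten by an unscaled one) with a single stacked `[cos, sin]` tensor scaled exactly once, and casts the result to the activation dtype and device on return. Below is a minimal standalone sketch, not the repository's code, that reproduces the patched `freqs_cis` computation in isolation and applies it to a toy query tensor. The `apply_rotary` helper, the adjacent-channel pairing convention, the tensor shapes, the `attention_scaling=1.0` default, and the rope base of 500000.0 are illustrative assumptions, not taken from `flash_llama4_modeling.py`.

```python
import torch


def compute_freqs_cis(inv_freq, position_ids, attention_scaling=1.0, dtype=torch.bfloat16):
    """Standalone version of the patched freqs_cis computation."""
    # inv_freq: (head_dim // 2,), position_ids: (batch, seq_len)
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # (batch, seq_len, head_dim // 2) rotation angles, computed in float32
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    # Stack cos/sin on a new trailing axis and apply attention_scaling exactly once,
    # as in the patched code; cast back to the activation dtype at the end.
    freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) * attention_scaling
    return freqs_cis.to(dtype=dtype)


def apply_rotary(x, freqs_cis):
    """Assumed rotary application: rotate adjacent channel pairs of x by (cos, sin)."""
    # x: (batch, seq_len, num_heads, head_dim); freqs_cis: (batch, seq_len, head_dim // 2, 2)
    xr, xi = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    cos = freqs_cis[:, :, None, :, 0].float()  # broadcast over the head axis
    sin = freqs_cis[:, :, None, :, 1].float()
    rotated = torch.stack([xr * cos - xi * sin, xr * sin + xi * cos], dim=-1)
    return rotated.reshape_as(x).to(x.dtype)


if __name__ == "__main__":
    batch, seq_len, num_heads, head_dim = 2, 8, 4, 64
    # Illustrative rope base; the model's actual config value is not assumed here.
    inv_freq = 1.0 / (500000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    position_ids = torch.arange(seq_len)[None, :].expand(batch, -1)
    q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
    freqs_cis = compute_freqs_cis(inv_freq, position_ids, dtype=q.dtype)
    print(apply_rotary(q, freqs_cis).shape)  # torch.Size([2, 8, 4, 64])
```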