import os
import math

import torch
from torch import nn
from habana_frameworks.torch.hpex.kernels import (
    RotaryPosEmbeddingMode,
    apply_rotary_pos_emb,
)


def _create_inv_freq(dim, base, device):
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )
    return inv_freq


def _get_rope_config(config):
    if os.getenv("ROPE_SCALING", None) is not None:
        rope_scaling = {
            "type": os.environ["ROPE_SCALING"],
            "factor": float(os.environ["ROPE_FACTOR"]),
        }
        return rope_scaling
    return getattr(config, "rope_scaling", None)


class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
        super().__init__()
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.scaling_factor = scaling_factor
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, inv_freq.device, max_position_embeddings
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        num_tokens = query.shape[0]
        head_size = query.shape[-1]
        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # to query hidden dimension, so the original tensors need to be
        # expanded
        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # and expansion of cos/sin tensors via concatenation
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        rotary_dim = cos.shape[-1]

        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))

        key_shape = key.shape
        key = key.view(num_tokens, -1, head_size)
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))

    @classmethod
    def static(cls, config, dim, base, device):
        inv_freq = _create_inv_freq(dim, base, device)
        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if not hasattr(config, "max_position_embeddings") and hasattr(
            config, "max_seq_len"
        ):
            # handling for dbrx
            config.max_position_embeddings = config.max_seq_len
        if rope_scaling is not None:
            # `rope_type` is now standard in transformers, but some existing models
            # have `type` instead.
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))

            if rope_type == "linear":
                pass
            elif rope_type == "default":
                pass
            elif rope_type == "mrope":
                mrope_section = rope_scaling["mrope_section"]
                if mrope_section is not None:
                    return RotaryPositionEmbeddingMultimodalSections(
                        inv_freq,
                        scaling_factor,
                        mrope_section,
                        config.max_position_embeddings,
                    )
            elif rope_type == "dynamic":
                scaling_factor = rope_scaling["factor"]
                return DynamicPositionRotaryEmbedding(
                    dim=dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_type == "llama3":
                inv_freq = apply_llama3_scaling(
                    inv_freq,
                    scaling_factor=rope_scaling["factor"],
                    low_freq_factor=rope_scaling["low_freq_factor"],
                    high_freq_factor=rope_scaling["high_freq_factor"],
                    original_max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                )
                return cls(inv_freq, scaling_factor, config.max_position_embeddings)
            elif rope_type == "yarn":
                scaling_factor = rope_scaling["factor"]
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            elif rope_type in ["su", "longrope"]:
                short_factor = torch.tensor(
                    rope_scaling["short_factor"], dtype=torch.float32, device=device
                )
                short_inv_freq = 1.0 / (
                    short_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )
                long_factor = torch.tensor(
                    rope_scaling["long_factor"], dtype=torch.float32, device=device
                )
                long_inv_freq = 1.0 / (
                    long_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )

                original_max_position_embeddings = (
                    config.original_max_position_embeddings
                )
                max_position_embeddings = config.max_position_embeddings
                if max_position_embeddings <= original_max_position_embeddings:
                    scaling_factor = 1.0
                else:
                    scale = max_position_embeddings / original_max_position_embeddings
                    scaling_factor = math.sqrt(
                        1 + math.log(scale) / math.log(original_max_position_embeddings)
                    )

                # if short_mscale and long_mscale are provided we need to scale the freqs
                # using the Phi3LongRoPEScaledRotaryEmbedding
                if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
                    short_mscale = rope_scaling["short_mscale"]
                    long_mscale = rope_scaling["long_mscale"]
                    return Phi3LongRoPEScaledRotaryEmbedding(
                        short_inv_freq=short_inv_freq,
                        long_inv_freq=long_inv_freq,
                        max_position_embeddings=config.max_position_embeddings,
                        short_mscale=short_mscale,
                        long_mscale=long_mscale,
                        original_max_position_embeddings=original_max_position_embeddings,
                    )

                return SuRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    scaling_factor=scaling_factor,
                    original_max_position_embeddings=original_max_position_embeddings,
                    max_position_embeddings=config.max_position_embeddings,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor, config.max_position_embeddings)

    @classmethod
    def load(cls, config, prefix, weights):
        # XXX: Always load this in float32 !
        dtype = weights.dtype
        weights.dtype = torch.float32
        inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
        weights.dtype = dtype

        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            scaling_factor = rope_scaling["factor"]
            if rope_scaling["type"] == "linear":
                pass
            elif rope_scaling["type"] == "dynamic":
                return DynamicPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=config.max_position_embeddings,
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_scaling["type"] == "yarn":
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor, config.max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            if self.scaling_factor is not None:
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)

        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation,
        # but we leave it as is to avoid yet another controlflow.
        return cos.unsqueeze(1), sin.unsqueeze(1)


class SuRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq,
        long_inv_freq,
        scaling_factor,
        original_max_position_embeddings,
        max_position_embeddings,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen

            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )
            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            freqs = torch.cat([short_freqs, long_freqs])

            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)


class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq: torch.Tensor,
        long_inv_freq: torch.Tensor,
        max_position_embeddings: int,
        short_mscale: float,
        long_mscale: float,
        original_max_position_embeddings: int,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.max_position_embeddings = max_position_embeddings
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale
        self.original_max_position_embeddings = original_max_position_embeddings

        # cache
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen

            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )
            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )
            short_freqs = short_freqs * self.short_mscale
            long_freqs = long_freqs * self.long_mscale
            freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
            freqs[: self.original_max_position_embeddings] = short_freqs
            freqs[self.original_max_position_embeddings :] = long_freqs

            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
        # Set the attributes used by the overridden _update_cos_sin_cache before
        # calling the parent constructor, which eagerly builds the cos/sin cache
        # through that override.
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                newbase = self.base * (
                    (self.scaling_factor * seqlen / self.max_position_embeddings)
                    - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


# Find dim range bounds based on rotations
def find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def get_mscale(scale: float = 1.0, mscale: float = 1.0):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        dim,
        max_position_embeddings,
        base,
        device,
        scaling_factor,
        *,
        extrapolation_factor,
        attn_factor,
        beta_fast,
        beta_slow,
        mscale: float,
        mscale_all_dim: float,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        # Set the attributes used by the overridden _update_cos_sin_cache before
        # calling the parent constructor, which eagerly builds the cos/sin cache
        # through that override.
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale_all_dim = mscale_all_dim
        self.scaling_factor = scaling_factor
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            get_mscale(self.scaling_factor, mscale)
            / get_mscale(self.scaling_factor, mscale_all_dim)
            * self.attn_factor
        )
        super().__init__(
            inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings or True:
                inv_freq_extrapolation = _create_inv_freq(
                    self.dim, self.base, self.inv_freq.device
                )
                freqs = 1.0 / inv_freq_extrapolation
                inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
                low, high = find_correction_range(
                    self.beta_fast,
                    self.beta_slow,
                    self.dim,
                    self.base,
                    self.max_position_embeddings,
                )

                inv_freq_mask = (
                    1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
                inv_freq = (
                    inv_freq_interpolation * (1 - inv_freq_mask)
                    + inv_freq_extrapolation * inv_freq_mask
                )
                self.inv_freq = inv_freq

            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)


def apply_llama3_scaling(
    freqs: torch.Tensor,
    *,
    scaling_factor: int,
    low_freq_factor: int,
    high_freq_factor: int,
    original_max_position_embeddings: int,
):
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    new_freqs = []

    for freq in freqs:
        wavelen = 2 * math.pi / freq

        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scaling_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)

    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
    def __init__(
        self,
        inv_freq: torch.Tensor,
        scaling_factor: float,
        sections: list,
        max_position_embeddings,
    ):
        self.sections = sections
        self._cos_cached = None
        self._sin_cached = None
        self.section_indices = (
            torch.arange(len(self.sections))
            .repeat_interleave(torch.tensor(self.sections))
            .view(1, 1, -1)
            .to(inv_freq.device)
        )
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(
        self, dtype: torch.dtype, device: torch.device, seqlen: int
    ):
        # always cache the cos/sin for the full sequence length to avoid
        # recomputing if the sequence length is smaller than the cached one
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
            self._sections = self.section_indices.expand(seqlen, -1, -1)

    def get_cos_sin(
        self,
        position_ids: torch.Tensor,
    ):
        slen = position_ids.shape[0]
        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
        return cos, sin
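

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not called by the server code).
# It shows how the pieces above are typically wired together: `static()` picks
# an embedding class from the model config, `get_cos_sin()` gathers the cached
# cos/sin rows for the token positions, and calling the module rotates query
# and key in place. The `SimpleNamespace` config and the tensor shapes below
# are assumptions made for this example; running it requires an Intel Gaudi
# (HPU) device, since the module imports the Habana RoPE kernel at the top.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("hpu")
    head_size, num_heads, num_kv_heads, num_tokens = 128, 8, 2, 16

    # Stand-in config with no rope scaling; real configs come from transformers.
    config = SimpleNamespace(max_position_embeddings=2048, rope_scaling=None)

    rotary = PositionRotaryEmbedding.static(
        config=config, dim=head_size, base=10000.0, device=device
    )

    position_ids = torch.arange(num_tokens, device=device)
    cos, sin = rotary.get_cos_sin(position_ids)  # each [num_tokens, 1, head_size // 2]

    # Query/key laid out as [num_tokens, num_heads, head_size]; rotation is in place.
    query = torch.randn(num_tokens, num_heads, head_size, device=device)
    key = torch.randn(num_tokens, num_kv_heads, head_size, device=device)
    rotary(query, key, cos, sin)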