Mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
* clean cuda/rocm code in hpu backend, enable flat_hpu
* fix TP in pageattn
* adjust block table in hpu to improve performance
* enable all the models (not tested yet)
* use tensor cache in hpu graph to avoid replay issue
* add moe support, fix qwen/mistral/mixtral crash
* fix phimoe issue
* gpt_bigcode could also go pageattn
* enable dbrx, remove some unused code
* multi-modality initial PR
* adjust warmup and enable vlm
* fix incorrect output in qwen2 idefics if hpu graph is used
* remove unused quantization code and enable awq/gptq int4
* fix gptq issue
* enable fp8
* warmup prefill: remove models where pageattn is not used, set block table to None since it's not used
* add warmup_decode
* warmup decode
* remove block_tables and prefill_cache_indices, which would lead to dynamic shapes
* fix comment
* missing gptj change...
* fix some issues
* remove torch.where to fix incorrect output in hpu graph model
* match the latest vllm_extension ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
607 lines
24 KiB
Python
import os
import math

import torch
from torch import nn

from habana_frameworks.torch.hpex.kernels import (
    RotaryPosEmbeddingMode,
    apply_rotary_pos_emb,
)


def _create_inv_freq(dim, base, device):
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )
    return inv_freq


def _get_rope_config(config):
    if os.getenv("ROPE_SCALING", None) is not None:
        rope_scaling = {
            "type": os.environ["ROPE_SCALING"],
            "factor": float(os.environ["ROPE_FACTOR"]),
        }
        return rope_scaling
    return getattr(config, "rope_scaling", None)


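# Note: `_get_rope_config` lets the ROPE_SCALING / ROPE_FACTOR environment
# variables override `config.rope_scaling`. Illustrative example (the values
# here are assumptions, not defaults): setting ROPE_SCALING=dynamic and
# ROPE_FACTOR=2.0 behaves like a config containing {"type": "dynamic", "factor": 2.0}.
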
class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
        super().__init__()
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.scaling_factor = scaling_factor
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, inv_freq.device, max_position_embeddings
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        num_tokens = query.shape[0]
        head_size = query.shape[-1]
        # The HPU RoPE kernel requires the hidden dimension of cos and sin to
        # equal the query hidden dimension, so the original tensors need to be
        # expanded.
        # The GPT-NeoX kernel requires position_ids = None, offset = 0,
        # mode = BLOCKWISE and expansion of the cos/sin tensors via concatenation.
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        rotary_dim = cos.shape[-1]
        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))

        key_shape = key.shape
        key = key.view(num_tokens, -1, head_size)
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))

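    # Shape sketch (illustrative, assuming the caller passes `query`/`key` as
    # (num_tokens, num_heads, head_size)): `get_cos_sin` yields cos/sin of
    # shape (num_tokens, 1, rotary_dim // 2); the torch.cat in `forward` above
    # doubles the last dimension so the kernel sees the full rotary width, and
    # the rotated values are written back in place with copy_().
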
    @classmethod
    def static(cls, config, dim, base, device):
        inv_freq = _create_inv_freq(dim, base, device)
        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if not hasattr(config, "max_position_embeddings") and hasattr(
            config, "max_seq_len"
        ):
            # handling for dbrx
            config.max_position_embeddings = config.max_seq_len
        if rope_scaling is not None:
            # `rope_type` is now standard in transformers, but some existing models
            # have `type` instead.
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))

if rope_type == "linear":
|
|
pass
|
|
elif rope_type == "default":
|
|
pass
|
|
elif rope_type == "mrope":
|
|
mrope_section = rope_scaling["mrope_section"]
|
|
if mrope_section is not None:
|
|
return RotaryPositionEmbeddingMultimodalSections(
|
|
inv_freq,
|
|
scaling_factor,
|
|
mrope_section,
|
|
config.max_position_embeddings,
|
|
)
|
|
elif rope_type == "dynamic":
|
|
scaling_factor = rope_scaling["factor"]
|
|
return DynamicPositionRotaryEmbedding(
|
|
dim=dim,
|
|
max_position_embeddings=config.max_position_embeddings,
|
|
base=base,
|
|
device=inv_freq.device,
|
|
scaling_factor=scaling_factor,
|
|
)
|
|
elif rope_type == "llama3":
|
|
inv_freq = apply_llama3_scaling(
|
|
inv_freq,
|
|
scaling_factor=rope_scaling["factor"],
|
|
low_freq_factor=rope_scaling["low_freq_factor"],
|
|
high_freq_factor=rope_scaling["high_freq_factor"],
|
|
original_max_position_embeddings=rope_scaling[
|
|
"original_max_position_embeddings"
|
|
],
|
|
)
|
|
|
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
|
|
|
elif rope_type == "yarn":
|
|
scaling_factor = rope_scaling["factor"]
|
|
mscale = rope_scaling.get("mscale", 1.0)
|
|
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
|
|
return YarnPositionRotaryEmbedding(
|
|
dim=2 * inv_freq.shape[0],
|
|
max_position_embeddings=rope_scaling[
|
|
"original_max_position_embeddings"
|
|
],
|
|
base=base,
|
|
device=inv_freq.device,
|
|
scaling_factor=scaling_factor,
|
|
extrapolation_factor=1,
|
|
attn_factor=1,
|
|
beta_fast=32,
|
|
beta_slow=1,
|
|
mscale=mscale,
|
|
mscale_all_dim=mscale_all_dim,
|
|
)
|
|
            elif rope_type in ["su", "longrope"]:
                short_factor = torch.tensor(
                    rope_scaling["short_factor"], dtype=torch.float32, device=device
                )
                short_inv_freq = 1.0 / (
                    short_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )
                long_factor = torch.tensor(
                    rope_scaling["long_factor"], dtype=torch.float32, device=device
                )
                long_inv_freq = 1.0 / (
                    long_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )

                original_max_position_embeddings = (
                    config.original_max_position_embeddings
                )
                max_position_embeddings = config.max_position_embeddings
                if max_position_embeddings <= original_max_position_embeddings:
                    scaling_factor = 1.0
                else:
                    scale = max_position_embeddings / original_max_position_embeddings
                    scaling_factor = math.sqrt(
                        1 + math.log(scale) / math.log(original_max_position_embeddings)
                    )

                # if short_mscale and long_mscale are provided we need to scale the freqs
                # using the Phi3LongRoPEScaledRotaryEmbedding
                if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
                    short_mscale = rope_scaling["short_mscale"]
                    long_mscale = rope_scaling["long_mscale"]
                    return Phi3LongRoPEScaledRotaryEmbedding(
                        short_inv_freq=short_inv_freq,
                        long_inv_freq=long_inv_freq,
                        max_position_embeddings=config.max_position_embeddings,
                        short_mscale=short_mscale,
                        long_mscale=long_mscale,
                        original_max_position_embeddings=original_max_position_embeddings,
                    )

                return SuRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    scaling_factor=scaling_factor,
                    original_max_position_embeddings=original_max_position_embeddings,
                    max_position_embeddings=config.max_position_embeddings,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_type} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor, config.max_position_embeddings)

    @classmethod
    def load(cls, config, prefix, weights):
        # XXX: Always load this in float32!
        dtype = weights.dtype
        weights.dtype = torch.float32
        inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
        weights.dtype = dtype

        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            scaling_factor = rope_scaling["factor"]
            if rope_scaling["type"] == "linear":
                pass
            elif rope_scaling["type"] == "dynamic":
                return DynamicPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=config.max_position_embeddings,
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
elif rope_scaling["type"] == "yarn":
|
|
mscale = rope_scaling.get("mscale", 1.0)
|
|
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
|
|
return YarnPositionRotaryEmbedding(
|
|
dim=2 * inv_freq.shape[0],
|
|
max_position_embeddings=rope_scaling[
|
|
"original_max_position_embeddings"
|
|
],
|
|
base=10000.0,
|
|
device=inv_freq.device,
|
|
scaling_factor=scaling_factor,
|
|
extrapolation_factor=1,
|
|
attn_factor=1,
|
|
beta_fast=32,
|
|
beta_slow=1,
|
|
mscale=mscale,
|
|
mscale_all_dim=mscale_all_dim,
|
|
)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
|
)
|
|
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
|
|
|
    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            if self.scaling_factor is not None:
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)

        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE
        # implementation, but we leave it as is to avoid yet another control flow.
        return cos.unsqueeze(1), sin.unsqueeze(1)


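# Usage sketch (illustrative only; `config`, `head_size`, `device`, `position_ids`,
# `query` and `key` are assumed to come from the calling attention layer):
#
#   rotary_emb = PositionRotaryEmbedding.static(
#       config=config, dim=head_size, base=10000.0, device=device
#   )
#   cos, sin = rotary_emb.get_cos_sin(position_ids)
#   rotary_emb(query, key, cos, sin)  # rotates query/key in place via copy_()
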
class SuRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq,
        long_inv_freq,
        scaling_factor,
        original_max_position_embeddings,
        max_position_embeddings,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen

            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )
            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            freqs = torch.cat([short_freqs, long_freqs])

            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)


class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq: torch.Tensor,
        long_inv_freq: torch.Tensor,
        max_position_embeddings: int,
        short_mscale: float,
        long_mscale: float,
        original_max_position_embeddings: int,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.max_position_embeddings = max_position_embeddings
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale
        self.original_max_position_embeddings = original_max_position_embeddings

        # cache
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)

            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )

            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            short_freqs = short_freqs * self.short_mscale
            long_freqs = long_freqs * self.long_mscale

            freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
            freqs[: self.original_max_position_embeddings] = short_freqs
            freqs[self.original_max_position_embeddings :] = long_freqs

            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
        # Attributes used by `_update_cos_sin_cache` must be set before calling
        # `super().__init__`, which eagerly builds the cos/sin cache.
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                newbase = self.base * (
                    (self.scaling_factor * seqlen / self.max_position_embeddings)
                    - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


# Find dim range bounds based on rotations
def find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def get_mscale(scale: float = 1.0, mscale: float = 1.0):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


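# The four helpers above implement the YaRN frequency blend used by
# `YarnPositionRotaryEmbedding` below: `find_correction_dim` inverts
# wavelength = 2 * pi * base^(2i / dim) to find the index whose wavelength
# completes `num_rotations` turns over the original context window,
# `find_correction_range` turns the (beta_fast, beta_slow) rotation bounds into
# a [low, high] index range, `linear_ramp_mask` blends interpolated and
# extrapolated inverse frequencies across that range, and `get_mscale` rescales
# the cached cos/sin magnitudes to compensate for interpolation.
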
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        dim,
        max_position_embeddings,
        base,
        device,
        scaling_factor,
        *,
        extrapolation_factor,
        attn_factor,
        beta_fast,
        beta_slow,
        mscale: float,
        mscale_all_dim: float,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        # Attributes used by `_update_cos_sin_cache` must be set before calling
        # `super().__init__`, which eagerly builds the cos/sin cache.
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale_all_dim = mscale_all_dim
        self.scaling_factor = scaling_factor
        self.mscale = float(
            get_mscale(self.scaling_factor, mscale)
            / get_mscale(self.scaling_factor, mscale_all_dim)
            * self.attn_factor
        )  # Get n-d magnitude scaling corrected for interpolation
        super().__init__(
            inv_freq, scaling_factor, max_position_embeddings * scaling_factor
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings or True:
                inv_freq_extrapolation = _create_inv_freq(
                    self.dim, self.base, self.inv_freq.device
                )
                freqs = 1.0 / inv_freq_extrapolation
                inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
                low, high = find_correction_range(
                    self.beta_fast,
                    self.beta_slow,
                    self.dim,
                    self.base,
                    self.max_position_embeddings,
                )

                inv_freq_mask = (
                    1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
                inv_freq = (
                    inv_freq_interpolation * (1 - inv_freq_mask)
                    + inv_freq_extrapolation * inv_freq_mask
                )

                self.inv_freq = inv_freq

            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)


def apply_llama3_scaling(
    freqs: torch.Tensor,
    *,
    scaling_factor: int,
    low_freq_factor: int,
    high_freq_factor: int,
    original_max_position_embeddings: int,
):
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    new_freqs = []

    for freq in freqs:
        wavelen = 2 * math.pi / freq

        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scaling_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)

    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


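# `apply_llama3_scaling` leaves high-frequency components (wavelength shorter
# than original_max_position_embeddings / high_freq_factor) untouched, divides
# low-frequency components by `scaling_factor`, and linearly interpolates
# between the two regimes for wavelengths in between.
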
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
    def __init__(
        self,
        inv_freq: torch.Tensor,
        scaling_factor: float,
        sections: list,
        max_position_embeddings,
    ):
        self.sections = sections
        self._cos_cached = None
        self._sin_cached = None
        self.section_indices = (
            torch.arange(len(self.sections))
            .repeat_interleave(torch.tensor(self.sections))
            .view(1, 1, -1)
            .to(inv_freq.device)
        )
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(
        self, dtype: torch.dtype, device: torch.device, seqlen: int
    ):
        # always cache the cos/sin for the full sequence length to avoid
        # recomputing if the sequence length is smaller than the cached one
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
            self._sections = self.section_indices.expand(seqlen, -1, -1)

    def get_cos_sin(
        self,
        position_ids: torch.Tensor,
    ):
        slen = position_ids.shape[0]

        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
        return cos, sin