Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-18 23:32:06 +00:00)
[gaudi] Move the _update_cos_sin_cache into get_cos_sin (#3254)
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent 613b8dd647
commit 25fdc5f03c
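In short, as the hunks below show: PositionRotaryEmbedding and its SuRotaryEmbedding / Phi3LongRoPEScaledRotaryEmbedding subclasses previously built the cos/sin tables eagerly in __init__. This change removes those eager _update_cos_sin_cache calls, records max_position_embeddings on the module instead, and has get_cos_sin (and the multimodal-sections forward path) refresh the cache lazily on the device of the incoming position_ids.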
@@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, inv_freq.device, max_position_embeddings
-        )
+        self.max_position_embeddings = max_position_embeddings

     def forward(
         self,
@@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_cached = torch.sin(freqs).to(dtype)

     def get_cos_sin(self, position_ids: torch.Tensor):
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
+        )
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)

@@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -354,9 +351,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         if (
@@ -598,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         position_ids: torch.Tensor,
     ):
         slen = position_ids.shape[0]
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
+        )

         cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
         sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
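To make the new control flow concrete, here is a minimal, runnable sketch of the pattern after this commit: the cos/sin tables are built lazily inside get_cos_sin rather than eagerly in __init__. The class and method names mirror the diff, but the inv_freq construction, the cache-invalidation guard, and the demo at the bottom are simplified illustrations, not TGI's exact code.

import torch
from torch import nn


class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq: torch.Tensor, max_position_embeddings: int):
        super().__init__()
        self.register_buffer("inv_freq", inv_freq)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        # After this commit: remember the length, but do not build the cache yet.
        self.max_position_embeddings = max_position_embeddings

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Rebuild the tables only if the cached length, device, or dtype no
        # longer matches (a simplified stand-in for the guard in the real code).
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        # The call moved here from __init__: the cache now follows the device
        # of the incoming position_ids instead of inv_freq's device at init.
        self._update_cos_sin_cache(
            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
        )
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos, sin


if __name__ == "__main__":
    dim = 64
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    rope = PositionRotaryEmbedding(inv_freq, max_position_embeddings=2048)
    cos, sin = rope.get_cos_sin(torch.arange(8))
    print(cos.shape, sin.shape)  # torch.Size([8, 32]) torch.Size([8, 32])

Presumably the motivation is that the tables end up on the device where position_ids actually live at runtime (the HPU in the Gaudi backend) rather than on inv_freq's device at construction time, and the duplicated eager setup in the SuRotaryEmbedding and Phi3LongRoPE subclasses can be deleted.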