Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 07:42:06 +00:00
Move the _update_cos_sin_cache into get_cos_sin
If the call sits in the class's __init__, it is invoked once per layer, since every layer constructs its own embedding instance; moving it into get_cos_sin defers the work to first use.

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent 70217ac345
commit 8b9a503f8a
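To illustrate the lazy pattern the commit moves to, here is a minimal, self-contained sketch, not TGI's implementation: `LazyRotaryCache` is a hypothetical stand-in for `PositionRotaryEmbedding`, and it sizes the cache from `position_ids.max() + 1` so the gather stays in range, whereas the commit itself passes `position_ids.shape[-1]`.

```python
import torch


class LazyRotaryCache:
    """Illustrative stand-in for PositionRotaryEmbedding, not TGI's code."""

    def __init__(self, inv_freq: torch.Tensor):
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        # Deliberately no eager _update_cos_sin_cache() here: with one
        # rotary module per decoder layer, an eager call would fill a
        # max_position_embeddings-sized table once per layer at load time.

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Rebuild only when the requested length outgrows the cache; this
        # mirrors the "Reset the tables if the sequence length has changed"
        # guard visible in the diff context below.
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        # Grow the cache on demand, then gather the per-position rows.
        # (Assumption: max()+1 sizing; the commit passes shape[-1].)
        self._update_cos_sin_cache(
            torch.float32, position_ids.device, seqlen=int(position_ids.max()) + 1
        )
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos, sin
```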
@@ -36,9 +36,6 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, inv_freq.device, max_position_embeddings
-        )

     def forward(
         self,
@@ -270,7 +267,9 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_cached = torch.sin(freqs).to(dtype)

     def get_cos_sin(self, position_ids: torch.Tensor):
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=position_ids.shape[-1]
+        )
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
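With the update moved into `get_cos_sin`, the first call pays the table-build cost and later calls hit the early-return guard, doing only the two `index_select` gathers. A hypothetical run of the `LazyRotaryCache` sketch above (the head dimension of 64 is an assumption, not a TGI default):

```python
# Hypothetical usage; nothing here is TGI's public API.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
rope = LazyRotaryCache(inv_freq)
assert rope._cos_cached is None        # nothing is built at construction

prefill = torch.arange(128)            # first call: builds a 128-row table
cos, sin = rope.get_cos_sin(prefill)
assert cos.shape == (128, 32)

decode = torch.tensor([128])           # new max position: cache grows once
cos, sin = rope.get_cos_sin(decode)
assert rope._seq_len_cached == 129

cos, sin = rope.get_cos_sin(decode)    # guard short-circuits: gathers only
```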
@@ -298,9 +297,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -354,9 +350,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         if (
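The same deletion is safe for `SuRotaryEmbedding` and `Phi3LongRoPEScaledRotaryEmbedding` because the inherited `get_cos_sin` calls `self._update_cos_sin_cache`, so each subclass's override is dispatched automatically on first use. A hedged sketch extending the hypothetical `LazyRotaryCache` above; the two-frequency split below is a simplification, not the real Su/Phi3 long-RoPE scaling math:

```python
class TwoPhaseRotaryCache(LazyRotaryCache):
    """Hypothetical subclass standing in for SuRotaryEmbedding."""

    def __init__(self, short_inv_freq, long_inv_freq, original_max_len):
        super().__init__(short_inv_freq)  # again: no eager cache build
        self.long_inv_freq = long_inv_freq
        self.original_max_len = original_max_len

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Short frequencies inside the original context window, long
            # frequencies beyond it (simplified stand-in for the scaling).
            short = torch.outer(t[: self.original_max_len], self.inv_freq)
            long = torch.outer(t[self.original_max_len :], self.long_inv_freq)
            freqs = torch.cat([short, long], dim=0)
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


# The inherited get_cos_sin dispatches to the override above, so deleting
# the eager __init__ call changes when the cache is built, not which code
# builds it.
rope2 = TwoPhaseRotaryCache(inv_freq, inv_freq / 4.0, original_max_len=64)
cos, sin = rope2.get_cos_sin(torch.arange(100))  # 100 rows, built lazily
```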