Update conditionals for dynamic scaling

Ian 2023-07-17 01:17:02 +00:00
parent f01c11bd0c
commit 0ec4d8182f
5 changed files with 7 additions and 23 deletions

View File

@@ -115,7 +115,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
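
Note: ROPE_SCALE_FACTOR defaults to 1, so the old guard "if self.scale_factor > 1" never selected the scaled rotary embedding when only dynamic scaling was requested; the new guard also fires when dynamic_scaling is set. A minimal standalone sketch of the predicate change (illustrative only; the helper names are not part of the repository):

# Sketch of the guard before and after this commit; helper names are hypothetical.
def uses_scaled_rope_old(scale_factor: int, dynamic_scaling: bool) -> bool:
    return scale_factor > 1

def uses_scaled_rope_new(scale_factor: int, dynamic_scaling: bool) -> bool:
    return scale_factor > 1 or dynamic_scaling

# With the default scale factor of 1 and dynamic scaling enabled:
assert uses_scaled_rope_old(1, True) is False   # old: dynamic scaling alone was ignored
assert uses_scaled_rope_new(1, True) is True    # new: the scaled embedding path is taken
assert uses_scaled_rope_new(1, False) is False  # defaults still use the unscaled path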

View File

@@ -45,13 +45,8 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 def load_row(config, prefix: str, weights, bias: bool):
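
Note: the replaced block passes the boolean False as the os.getenv default, so calling .lower() raises AttributeError whenever ROPE_DYNAMIC_SCALING is unset; the one-line replacement always operates on a string. A standalone sketch of the difference (illustrative only):

import os

def parse_old() -> bool:
    # Old pattern: when the variable is unset, os.getenv returns the bool False,
    # and False.lower() raises AttributeError.
    return os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true"

def parse_new() -> bool:
    # New pattern: the default is the string "false", so .lower() is always valid.
    return os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"

os.environ.pop("ROPE_DYNAMIC_SCALING", None)
try:
    parse_old()
except AttributeError:
    print("old parsing fails when ROPE_DYNAMIC_SCALING is unset")
print(parse_new())  # False when unset; True only for the value "true" (any casing)
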
@@ -114,7 +109,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings

View File

@@ -26,11 +26,7 @@ from text_generation_server.utils.layers import (
 )
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

View File

@@ -60,14 +60,8 @@ if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
 if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
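
Note: in each of these modules the two settings are read once at import time, so they must be present in the environment before the modeling code is imported (for example, exported when launching the server process). A hedged sketch with example values (the values shown are illustrative, not defaults from this commit):

import os

# Set before the modeling modules are imported; the values are examples only.
os.environ["ROPE_SCALE_FACTOR"] = "4"        # static scaling factor, defaults to 1
os.environ["ROPE_DYNAMIC_SCALING"] = "true"  # the value "true" (any casing) enables dynamic scaling

# Mirrors the module-level parsing in the hunks above.
ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
print(ROPE_SCALE_FACTOR, ROPE_DYNAMIC_SCALING)  # 4 True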

View File

@@ -423,10 +423,9 @@ try:
             if self.dynamic_scaling:
                 scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
                 max_seq_len = self.original_max_seq_len * scale_factor
-                inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
-                self.register_buffer("inv_freq", inv_freq)
+                self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
-            if self.scale_factor > 1:
+            if self.scale_factor > 1 and not self.dynamic_scaling:
                 length = max(seqlen, max_seq_len)
             if (
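
Note: in the dynamic branch the effective scale factor is recomputed from the current length as (scale_factor * length / original_max_seq_len) - (scale_factor - 1), and max_seq_len grows with it. A worked example with illustrative numbers (not taken from the repository):

# Dynamic scaling factor from this hunk, with example values: a configured
# scale factor of 2 and an original context of 2048 positions.
configured_scale = 2
original_max_seq_len = 2048

def dynamic_scale(length: int) -> float:
    return (configured_scale * length / original_max_seq_len) - (configured_scale - 1)

print(dynamic_scale(2048))  # 1.0 -> no extra scaling at the original context length
print(dynamic_scale(3072))  # 2.0
print(dynamic_scale(4096))  # 3.0 -> the factor grows with the sequence length
# max_seq_len then becomes original_max_seq_len * dynamic_scale(length), as in the hunk above.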