mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Update conditionals for dynamic scaling

parent f01c11bd0c
commit 0ec4d8182f
@@ -115,7 +115,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING

-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
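The functional change in this hunk is the broadened guard: with the old condition, requesting dynamic scaling while ROPE_SCALE_FACTOR stayed at its default of 1 never reached the scaled rotary-embedding path. A minimal standalone sketch of that selection logic, using a hypothetical choose_rotary_embedding helper rather than TGI's actual constructor:

ROPE_SCALE_FACTOR = 1          # default: no static scaling requested
ROPE_DYNAMIC_SCALING = True    # dynamic scaling requested via env var

def choose_rotary_embedding(scale_factor: int, dynamic_scaling: bool) -> str:
    # Old guard was `scale_factor > 1`, so dynamic scaling alone fell through
    # to the unscaled path; the new guard also accepts the dynamic flag.
    if scale_factor > 1 or dynamic_scaling:
        return f"static rotary embedding (base=10000, scale={scale_factor}, dynamic={dynamic_scaling})"
    return "rotary embedding built without scaling"

print(choose_rotary_embedding(ROPE_SCALE_FACTOR, ROPE_DYNAMIC_SCALING))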
@@ -45,13 +45,8 @@ from text_generation_server.utils.layers import (
     get_linear,
 )


 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"


 def load_row(config, prefix: str, weights, bias: bool):
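Beyond collapsing the if/else, the one-liner fixes the default: os.getenv with a bool default returns False itself when the variable is unset, and False has no .lower(), so the old form raised AttributeError whenever ROPE_DYNAMIC_SCALING was not exported. A small sketch of the two behaviours, using only the standard library:

import os

os.environ.pop("ROPE_DYNAMIC_SCALING", None)  # simulate the variable being unset

try:
    # Old form: the default is the bool False, which has no .lower() method.
    os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true"
except AttributeError as exc:
    print("old form fails when unset:", exc)

# New form: string default, parsed to a bool in a single expression.
ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
print(ROPE_DYNAMIC_SCALING)  # False when unset, True only when set to "true"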
@@ -114,7 +109,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING

-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings
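The comment in these hunks notes that 10000 is the unscaled RoPE base. Purely as an illustration, one common way to fold a scale factor into that base is the "NTK-aware" adjustment sketched below; whether PositionRotaryEmbedding.static applies exactly this formula is not visible in this diff, and the dim and scale values here are made up:

import torch

def inv_freq_with_scaled_base(dim: int, base: float, scale_factor: float) -> torch.Tensor:
    # NTK-aware style adjustment: stretch the base so low frequencies are
    # compressed more than high ones, instead of linearly rescaling positions.
    if scale_factor > 1:
        base = base * scale_factor ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

print(inv_freq_with_scaled_base(dim=128, base=10000.0, scale_factor=2.0)[:4])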
@@ -26,11 +26,7 @@ from text_generation_server.utils.layers import (
 )

 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"

 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
@@ -60,14 +60,8 @@ if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
 if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")


 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"


 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
@@ -423,10 +423,9 @@ try:
             if self.dynamic_scaling:
                 scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
                 max_seq_len = self.original_max_seq_len * scale_factor
-                inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
-                self.register_buffer("inv_freq", inv_freq)
+                self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)

-            if self.scale_factor > 1:
+            if self.scale_factor > 1 and not self.dynamic_scaling:
                 length = max(seqlen, max_seq_len)

             if (
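For intuition, the dynamic branch recomputes the scale from the current sequence length: the formula yields exactly 1 when length equals original_max_seq_len and grows linearly beyond it. A quick worked example with illustrative numbers (the variable names mirror the hunk above, the values do not come from it):

configured_scale = 2          # would come from ROPE_SCALE_FACTOR
original_max_seq_len = 2048   # would come from config.max_position_embeddings

for length in (2048, 4096, 8192, 16384):
    # Formula from the hunk above: 1.0 at the original context length,
    # increasing linearly as the current sequence length grows past it.
    scale_factor = (configured_scale * length / original_max_seq_len) - (configured_scale - 1)
    max_seq_len = original_max_seq_len * scale_factor
    print(f"length={length:5d}  scale_factor={scale_factor:5.2f}  max_seq_len={max_seq_len:8.0f}")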