fix: prefer llama base model and improve rotary logic

drbh 2024-09-02 16:19:45 +00:00 committed by Daniël de Kok
parent 853bc514f2
commit dff1b9f795
4 changed files with 224 additions and 263 deletions


@@ -90,21 +90,43 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear":
pass
elif rope_type == "longrope":
inv_freq = apply_phi3_scaling(
inv_freq,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
short_factor=rope_scaling["short_factor"],
long_factor=rope_scaling["long_factor"],
short_mscale=rope_scaling["short_mscale"],
long_mscale=rope_scaling["long_mscale"],
original_max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
device=inv_freq.device,
dim=dim,
)
return cls(inv_freq, scaling_factor)
# Phi3LongRoPEScaledRotaryEmbedding
short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device
)
long_factor = torch.tensor(
rope_scaling["long_factor"], dtype=torch.float32, device=device
)
short_mscale = rope_scaling["short_mscale"]
long_mscale = rope_scaling["long_mscale"]
original_max_position_embeddings = (
config.original_max_position_embeddings
)
return Phi3LongRoPEScaledRotaryEmbedding(
short_inv_freq=1.0
/ (
short_factor
* base
** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
/ dim
)
),
long_inv_freq=1.0
/ (
long_factor
* base
** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
/ dim
)
),
max_position_embeddings=config.max_position_embeddings,
short_mscale=short_mscale,
long_mscale=long_mscale,
original_max_position_embeddings=original_max_position_embeddings,
)
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(
@@ -324,6 +346,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
def __init__(
self,
short_inv_freq,
long_inv_freq,
max_position_embeddings,
short_mscale,
long_mscale,
original_max_position_embeddings,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
self.long_inv_freq = long_inv_freq
self.max_position_embeddings = max_position_embeddings
self.short_mscale = short_mscale
self.long_mscale = long_mscale
self.original_max_position_embeddings = original_max_position_embeddings
# cache
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
short_freqs = torch.outer(
t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
short_freqs = short_freqs * self.short_mscale
long_freqs = long_freqs * self.long_mscale
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
freqs[: self.original_max_position_embeddings] = short_freqs
freqs[self.original_max_position_embeddings :] = long_freqs
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
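For reference, a standalone sketch of the longrope cache logic above: positions inside the original context window use the short inverse frequencies, later positions use the long ones, and each block is scaled by its mscale before cos/sin are taken. The helper name and the toy sizes below are illustrative only, not part of this commit.

import torch

def phi3_longrope_cos_sin(short_inv_freq, long_inv_freq, seqlen,
                          original_max_pos, short_mscale, long_mscale,
                          dtype=torch.float32, device="cpu"):
    # positions 0..seqlen-1 in float32, matching the cache update above
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    # short factors cover the original window, long factors the extended range
    short_freqs = torch.outer(t[:original_max_pos], short_inv_freq.to(device))
    long_freqs = torch.outer(t[original_max_pos:], long_inv_freq.to(device))
    freqs = torch.cat([short_freqs * short_mscale, long_freqs * long_mscale], dim=0)
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

# toy usage: dim=8, theta=10000, original window of 4 positions
dim, base = 8, 10000.0
inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
cos, sin = phi3_longrope_cos_sin(inv, inv / 2.0, seqlen=6, original_max_pos=4,
                                 short_mscale=1.0, long_mscale=1.1)
print(cos.shape)  # torch.Size([6, 4])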
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
@@ -491,58 +570,3 @@ def apply_llama3_scaling(
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def apply_phi3_scaling(
freqs: torch.Tensor,
*,
max_position_embeddings: int,
rope_theta: int,
short_factor: torch.Tensor,
long_factor: torch.Tensor,
short_mscale: float,
long_mscale: float,
original_max_position_embeddings: int,
device=None,
dim=None,
):
base = rope_theta
long_rescale_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
short_rescale_factors = torch.tensor(
short_factor, dtype=torch.float32, device=device
)
long_inv_freq = 1.0 / (
long_rescale_factors
* (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
)
short_inv_freq = 1.0 / (
short_rescale_factors
* (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
)
# original_max_position_embeddings = torch.tensor(original_max_position_embeddings, device=device)
# low_freq_factor = torch.tensor(long_factor, dtype=torch.float32, device=device)
# high_freq_factor = torch.tensor(short_factor, dtype=torch.float32, device=device)
# low_freq_wavelen = original_max_position_embeddings / low_freq_factor
# high_freq_wavelen = original_max_position_embeddings / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
# if wavelen < high_freq_wavelen:
if True:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / short_mscale)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append(
(1 - smooth) * freq / short_mscale + smooth * freq / long_mscale
)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


@@ -777,7 +777,7 @@ def get_model(
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,


@@ -47,11 +47,17 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.layers import (
FastLinear,
)
from text_generation_server.utils.weights import (
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
if SYSTEM == "rocm":
try:
from vllm import _custom_C
@@ -245,6 +251,103 @@ class FlashLlamaAttention(torch.nn.Module):
)
def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
assert mat in ["w1", "w2", "w3"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device,
)
for i in range(config.num_local_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "w2":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
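A toy illustration of the sharding scheme in _load_experts above: every rank keeps one contiguous block_size slice of intermediate_size per expert, so the merged tensor has shape (num_local_experts * block_size, hidden_size); for w2 the slice is taken over columns and transposed instead. All sizes below are made up for the example.

import torch

num_local_experts, intermediate_size, hidden_size, world_size = 4, 16, 8, 2
block_size = intermediate_size // world_size  # rows each rank keeps per expert

w1_full = [torch.randn(intermediate_size, hidden_size) for _ in range(num_local_experts)]

for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    shard = torch.empty((num_local_experts * block_size, hidden_size))
    for i, w1 in enumerate(w1_full):
        shard[i * block_size : (i + 1) * block_size] = w1[start:stop]
    print(rank, shard.shape)  # same shape on every rank, different rows of each expert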
class BlockSparseMoE(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.w13 = torch.cat([w1, w3], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts", "w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)
self.process_group = weights.process_group
def forward(self, x, adapter_data) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
self.w13,
self.w2,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
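To make the fused call above easier to follow, here is an unfused reference of what fused_moe computes with renormalize=True, assuming SiLU-gated experts where w13 stacks the gate (w1) and up (w3) projections and w2 is the down projection. This is a readability sketch, not the kernel the server actually runs.

import torch

def moe_reference(x, w13, w2, router_logits, top_k):
    # x: (tokens, hidden); w13: (E, 2*ffn, hidden); w2: (E, hidden, ffn)
    num_experts, two_ffn, _ = w13.shape
    ffn = two_ffn // 2
    probs = torch.softmax(router_logits, dim=-1)
    weights, experts = torch.topk(probs, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize=True
    out = torch.zeros_like(x)
    for e in range(num_experts):
        token_idx, slot = torch.where(experts == e)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].t()            # (n, 2*ffn)
        gate, up = h[:, :ffn], h[:, ffn:]        # w1 rows first, then w3 rows
        expert_out = (torch.nn.functional.silu(gate) * up) @ w2[e].t()
        out[token_idx] += weights[token_idx, slot, None] * expert_out
    return out

Under tensor parallelism each rank only holds a slice of every expert, so the all_reduce in forward sums the partial expert outputs across ranks.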
class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights, index):
super().__init__()
@@ -353,9 +456,14 @@ class FlashLlamaLayer(nn.Module):
weights=weights,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
if self.use_moe:
self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
else:
self.dense = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -401,7 +509,7 @@ class FlashLlamaLayer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
mlp_output = self.dense(normed_attn_res_output, adapter_data)
return mlp_output, attn_res


@@ -27,15 +27,6 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.layers import (
FastLinear,
)
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
class PhiConfig(PretrainedConfig):
def __init__(
@@ -78,16 +69,7 @@ class PhiConfig(PretrainedConfig):
# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights):
if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
if config.num_attention_heads != config.num_key_value_heads \
and False:
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
@@ -108,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize and (config.quantize not in ["gptq", "awq", "marlin"]):
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
@@ -123,103 +105,6 @@ def _load_gqa(config, prefix: str, weights):
return TensorParallelColumnLinear(get_linear(weight, bias=True))
def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
assert mat in ["w1", "w2", "w3"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device,
)
for i in range(config.num_local_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "w2":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
class BlockSparseMoE(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.w13 = torch.cat([w1, w3], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts", "w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
self.w13,
self.w2,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class FlashPhiAttention(torch.nn.Module):
def __init__(
self,
@@ -233,12 +118,7 @@ class FlashPhiAttention(torch.nn.Module):
self.head_size = self.hidden_size // self.num_heads
self.softmax_scale = self.head_size**-0.5
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
self.head_dim = self.head_size
if hasattr(config, "partial_rotary_factor"):
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
else:
self.rotary_dim = self.head_size
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
@@ -261,10 +141,9 @@ class FlashPhiAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
# in llama the dense layer is called "o_proj" and has bias=False
proj_layer_name = f"{prefix}.o_proj" if self.use_moe else f"{prefix}.dense"
self.dense = TensorParallelRowLinear.load(
config,
prefix=proj_layer_name,
prefix=f"{prefix}.dense",
weights=weights,
bias=True,
)
@@ -304,14 +183,9 @@ class FlashPhiAttention(torch.nn.Module):
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
#
# Apply partial positional embeddings in place
if self.use_moe:
# rotate half rotary_dim
# half_size = torch.select(kv, dim=1, index=0).size(1) // 2
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
else:
self.rotary_emb(
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
)
self.rotary_emb(
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
)
# Reshape key and value and cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
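The hunk above drops the MoE special case so partial rotary embeddings are always applied to the first rotary_dim dimensions of query and key, leaving the remaining dimensions untouched. A minimal out-of-place sketch of that idea, assuming an interleaved pairing and cos/sin of shape (num_tokens, rotary_dim // 2); the in-tree rotary kernel works in place and its pairing convention may differ.

import torch

def apply_partial_rotary(q, k, cos, sin, rotary_dim):
    # q, k: (num_tokens, num_heads, head_size); cos, sin: (num_tokens, rotary_dim // 2)
    cos = cos[:, None, :]  # broadcast over heads
    sin = sin[:, None, :]

    def rotate(x):
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        x1, x2 = x_rot[..., ::2], x_rot[..., 1::2]  # interleaved pairs
        out = torch.empty_like(x_rot)
        out[..., ::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
        return torch.cat((out, x_pass), dim=-1)

    return rotate(q), rotate(k)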
@@ -384,30 +258,13 @@ class FlashPhiLayer(nn.Module):
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
if self.use_moe:
self.moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
else:
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
# eps=config.layer_norm_eps,
eps=1e-5,
eps=config.layer_norm_eps,
)
if self.use_moe:
self.post_attn_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=1e-5,
)
# self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
self.resid_dropout = torch.nn.Dropout(config.attention_dropout)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
def forward(
self,
@@ -436,28 +293,9 @@ class FlashPhiLayer(nn.Module):
max_s,
)
if self.use_moe:
# hidden_states = attn_output
# if residual is not None:
# hidden_states = hidden_states + residual
# residual = hidden_states
# hidden_states, router_states = self.post_attn_layernorm(
# hidden_states
# )
# hidden_states = self.moe(hidden_states)
# # hidden_states = hidden_states + residual
# res = residual
hidden_states = self.resid_dropout(attn_output)
_hidden_states = self.resid_dropout(self.moe(hidden_states))
hidden_states = hidden_states.add(_hidden_states)
else:
hidden_states = self.resid_dropout(attn_output).add(
self.resid_dropout(self.mlp(hidden_states))
)
hidden_states = self.resid_dropout(attn_output).add(
self.resid_dropout(self.mlp(hidden_states))
)
return hidden_states, res
@@ -489,20 +327,11 @@ class FlashPhiModel(torch.nn.Module):
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
if self.use_moe:
self.norm = FastLayerNorm.load(
prefix="model.norm",
weights=weights,
eps=1e-5,
)
else:
self.norm = FastLayerNorm.load(
prefix="model.final_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.norm = FastLayerNorm.load(
prefix="model.final_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,