feat: add quant to mixtral (#1337)

This commit is contained in:
OlivierDehaene 2023-12-12 17:55:03 +01:00 committed by GitHub
parent ec6d4592d5
commit 82670d9786
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 184 additions and 35 deletions

View File

@ -434,8 +434,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")
def forward( def forward(
self, self,
@ -454,7 +452,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
if prefill_cache_indices is not None: if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor # Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices] slots = slots[prefill_cache_indices]
else: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
max_s = min(self.max_past, max_s) max_s = min(self.max_past, max_s)

View File

@ -365,9 +365,9 @@ class BlockSparseMoE(nn.Module):
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t() self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights) self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t() self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
self.offsets = None self.offsets = None
self.offsets_block_rows = 0 self.offsets_block_rows = 0
@ -467,8 +467,7 @@ class BlockSparseMoE(nn.Module):
return indices, bin_ids, bins, padded_bins, tokens_per_expert return indices, bin_ids, bins, padded_bins, tokens_per_expert
@torch.inference_mode() def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" """
x: (sequence_length, model_dim) x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts) gate_logits: (sequence_length, n_experts)
@ -502,8 +501,8 @@ class BlockSparseMoE(nn.Module):
# (top_k * sequence_length + padding, ffn_dim * n_experts) # (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix( x = stk.Matrix(
topo.size(), topo.size(),
self.act(stk.ops.sdd(x, self.w1, topo).data) self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.w3, topo).data, * stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices, topo.row_indices,
topo.column_indices, topo.column_indices,
topo.offsets, topo.offsets,
@ -534,6 +533,156 @@ class BlockSparseMoE(nn.Module):
return x.view(*input_shape) return x.view(*input_shape)
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)
# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)
# Sum experts
out = out.sum(0)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x) > 256:
return self.sparse_forward(x)
# This is faster when there is not a lot of tokens
return self.dense_forward(x)
class DenseMoE(nn.Module):
def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.w1 = [
TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.w3 = [
TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.w2 = [
TensorParallelRowLinear.load(
config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)
# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
# Final output tensor
out = x.new_zeros(x.shape[0], self.hidden_dim)
for i in range(self.num_experts):
h = self.act(self.w1[i](x)) * self.w3[i](x)
h = self.w2[i](h, reduce=False)
# Add expert output to out with masking
out += h * weights[:, i].view(-1, 1)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class MixtralLayer(nn.Module): class MixtralLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, config, weights):
@ -543,9 +692,9 @@ class MixtralLayer(nn.Module):
self.self_attn = MixtralAttention( self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
self.block_sparse_moe = BlockSparseMoE(
f"{prefix}.block_sparse_moe", config, weights moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
) self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
self.input_layernorm = FastRMSNorm.load( self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -591,9 +740,9 @@ class MixtralLayer(nn.Module):
attn_output, res attn_output, res
) )
block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output) moe_output = self.moe(normed_attn_res_output)
return block_sparse_moe_output, attn_res return moe_output, attn_res
class MixtralModel(torch.nn.Module): class MixtralModel(torch.nn.Module):
@ -675,8 +824,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")
def forward( def forward(
self, self,
@ -695,7 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
if prefill_cache_indices is not None: if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor # Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices] slots = slots[prefill_cache_indices]
else: elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
max_s = min(self.max_past, max_s) max_s = min(self.max_past, max_s)

View File

@ -136,9 +136,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
total_tokens = input_length + max_new_tokens - 1 + speculative_length total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = min( needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS if SLIDING_WINDOW_BLOCKS is not None:
) needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
blocks += needed_blocks blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens)) needed_blocks_slots.append((needed_blocks, total_tokens))
@ -152,12 +152,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
slot_indices.append(request_slot_indices) slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill # Create tensor to slice into the kv tensor in prefill
request_prefill_cache_indices = torch.arange( if SLIDING_WINDOW is not None:
cumulative_length + max(0, input_length - SLIDING_WINDOW), request_prefill_cache_indices = torch.arange(
cumulative_length + input_length, cumulative_length + max(0, input_length - SLIDING_WINDOW),
dtype=torch.int64, cumulative_length + input_length,
) dtype=torch.int64,
prefill_cache_indices.append(request_prefill_cache_indices) )
prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@ -209,12 +210,14 @@ class FlashMistralBatch(FlashCausalLMBatch):
input_ids = np.concatenate(all_input_ids, dtype=np.int64) input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids) position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices) slot_indices = torch.cat(slot_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices) if SLIDING_WINDOW is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else: else:
input_ids = all_input_ids[0] input_ids = all_input_ids[0]
position_ids = position_ids[0] position_ids = position_ids[0]
slot_indices = slot_indices[0] slot_indices = slot_indices[0]
prefill_cache_indices = prefill_cache_indices[0] if SLIDING_WINDOW is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32 cu_seqlen_prefill, device=device, dtype=torch.int32
@ -222,7 +225,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
position_ids = position_ids.to(device) position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
prefill_cache_indices = prefill_cache_indices.to(device) prefill_cache_indices = (
prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor( input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device input_lengths, dtype=torch.int32, device=device
@ -314,8 +319,9 @@ class BaseFlashMistral(FlashCausalLM):
config.quantize = quantize config.quantize = quantize
# Set context windows # Set context windows
SLIDING_WINDOW = config.sliding_window if config.sliding_window is not None:
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE) SLIDING_WINDOW = config.sliding_window
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -64,8 +64,6 @@ elif CAN_EXLLAMA:
except ImportError: except ImportError:
pass pass
from typing import Optional
HAS_EETQ = False HAS_EETQ = False
try: try:
from EETQ import quant_weights, w8_a16_gemm from EETQ import quant_weights, w8_a16_gemm
@ -489,9 +487,9 @@ class TensorParallelRowLinear(SuperLayer):
process_group=weights.process_group, process_group=weights.process_group,
) )
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input) out = super().forward(input)
if self.process_group.size() > 1: if self.process_group.size() > 1 and reduce:
torch.distributed.all_reduce(out, group=self.process_group) torch.distributed.all_reduce(out, group=self.process_group)
return out return out