Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
feat: add quant to mixtral (#1337)
Commit 82670d9786 (parent ec6d4592d5)
@@ -434,8 +434,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -454,7 +452,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
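Since the sliding window is now optional, max_past may legitimately be None (the ValueError guard above is removed) and the decode-time clamp only runs when a window is actually configured. A minimal sketch of that control flow, assuming the same meaning of max_s and max_past as in the hunk above (the helper function itself is hypothetical):

from typing import Optional


def clamp_max_s(max_s: int, max_past: Optional[int], is_prefill: bool) -> int:
    """Hypothetical helper mirroring the elif branch above.

    In prefill the true sequence length is kept (the flash-attention kernel
    needs it); in decode, paged attention only ever sees at most `max_past`
    cached tokens, so the length can be clamped when a window is configured.
    """
    if is_prefill or max_past is None:
        return max_s
    return min(max_past, max_s)


assert clamp_max_s(10_000, 4096, is_prefill=False) == 4096
assert clamp_max_s(10_000, None, is_prefill=False) == 10_000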
@@ -365,9 +365,9 @@ class BlockSparseMoE(nn.Module):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
+        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
         self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
+        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
 
         self.offsets = None
         self.offsets_block_rows = 0
@@ -467,8 +467,7 @@ class BlockSparseMoE(nn.Module):
 
         return indices, bin_ids, bins, padded_bins, tokens_per_expert
 
-    @torch.inference_mode()
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: (sequence_length, model_dim)
         gate_logits: (sequence_length, n_experts)
@@ -502,8 +501,8 @@ class BlockSparseMoE(nn.Module):
         # (top_k * sequence_length + padding, ffn_dim * n_experts)
         x = stk.Matrix(
             topo.size(),
-            self.act(stk.ops.sdd(x, self.w1, topo).data)
-            * stk.ops.sdd(x, self.w3, topo).data,
+            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
+            * stk.ops.sdd(x, self.w3.t(), topo).data,
             topo.row_indices,
             topo.column_indices,
             topo.offsets,
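The merged expert weights now stay in their loaded (n_experts * ffn_dim, hidden_dim) layout instead of being transposed at load time; the sparse path applies .t() at the stk.ops.sdd call sites instead (see the @@ -502 hunk above), which lets the new dense path view the very same tensor per expert. A shape-only sketch with toy sizes (dimensions are illustrative, not the real Mixtral shapes):

import torch

n_experts, ffn_dim, hidden_dim = 4, 16, 8  # toy sizes
w1 = torch.randn(n_experts * ffn_dim, hidden_dim)

# Dense path: one (ffn_dim, hidden_dim) matrix per expert, no copy needed.
per_expert = w1.view(n_experts, ffn_dim, hidden_dim)
assert per_expert.shape == (n_experts, ffn_dim, hidden_dim)

# Sparse path: the sdd call now receives w1.t(), i.e. the same
# (hidden_dim, n_experts * ffn_dim) orientation the old code produced
# by transposing once at load time.
assert w1.t().shape == (hidden_dim, n_experts * ffn_dim)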
@@ -534,6 +533,156 @@ class BlockSparseMoE(nn.Module):
 
         return x.view(*input_shape)
 
+    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Expand to [num_experts, sequence_length, model_dim]
+        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
+
+        # Permute to [num_experts, model_dim, ffn_dim]
+        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+
+        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
+
+        out = torch.bmm(
+            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
+        )
+        # Mask not selected experts
+        out *= weights.t().view(self.num_experts, -1, 1)
+
+        # Sum experts
+        out = out.sum(0)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x) > 256:
+            return self.sparse_forward(x)
+        # This is faster when there is not a lot of tokens
+        return self.dense_forward(x)
+
+
+class DenseMoE(nn.Module):
+    def __init__(self, prefix, config: MixtralConfig, weights):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size // weights.process_group.size()
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        act = config.hidden_act
+        if "gelu" in act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
+        elif "silu" in act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[act]
+
+        # gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        self.w1 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w3 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w2 = [
+            TensorParallelRowLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Final output tensor
+        out = x.new_zeros(x.shape[0], self.hidden_dim)
+        for i in range(self.num_experts):
+            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.w2[i](h, reduce=False)
+            # Add expert output to out with masking
+            out += h * weights[:, i].view(-1, 1)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
 class MixtralLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
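Both the new BlockSparseMoE.dense_forward (picked for small batches, per the forward dispatch above) and the new DenseMoE class share the same routing trick: softmax over the gate logits, zero the probabilities of every expert outside the per-token top-k, renormalize, then compute all experts and sum their outputs weighted by those probabilities. A self-contained sketch of that routing with toy shapes (every name and size here is illustrative, and the SwiGLU w3 branch is omitted for brevity):

import torch

torch.manual_seed(0)

seq_len, model_dim, ffn_dim, num_experts, top_k = 5, 8, 16, 4, 2

x = torch.randn(seq_len, model_dim)
gate_logits = torch.randn(seq_len, num_experts)

# Softmax over experts, then zero out everything but the top-k per token.
all_probs = torch.softmax(gate_logits, dim=1)
_, not_selected = torch.topk(all_probs, num_experts - top_k, largest=False, dim=1)
all_probs.scatter_(1, not_selected, 0.0)
weights = all_probs / all_probs.sum(dim=1, keepdim=True)  # re-normalize the top-k

# "Dense" MoE: run every expert on every token, then weight and sum.
w1 = torch.randn(num_experts, model_dim, ffn_dim)
w2 = torch.randn(num_experts, ffn_dim, model_dim)
inter = torch.nn.functional.silu(torch.matmul(x, w1))  # broadcast over experts
out = torch.matmul(inter, w2)                  # (num_experts, seq_len, model_dim)
out = (out * weights.t().unsqueeze(-1)).sum(0)  # mask unselected experts + sum

assert out.shape == (seq_len, model_dim)

Running all experts is wasteful in FLOPs, but for short inputs it avoids the gather/scatter bookkeeping of the block-sparse path, which is why forward above switches on len(x) > 256 ("This is faster when there is not a lot of tokens").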
@@ -543,9 +692,9 @@ class MixtralLayer(nn.Module):
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.block_sparse_moe = BlockSparseMoE(
-            f"{prefix}.block_sparse_moe", config, weights
-        )
+        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -591,9 +740,9 @@ class MixtralLayer(nn.Module):
             attn_output, res
         )
 
-        block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
+        moe_output = self.moe(normed_attn_res_output)
 
-        return block_sparse_moe_output, attn_res
+        return moe_output, attn_res
 
 
 class MixtralModel(torch.nn.Module):
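The layer now chooses its MoE implementation at construction time: with no quantization it keeps the fused BlockSparseMoE over the raw merged weight tensors, while any quantize setting switches to DenseMoE, whose per-expert TensorParallelColumnLinear / TensorParallelRowLinear layers go through the regular (quantizable) linear loading path. A minimal sketch of the dispatch pattern (the config class below is a stand-in, not the real MixtralConfig):

from dataclasses import dataclass
from typing import Optional


class BlockSparseMoE:  # stand-ins for the real modules above
    pass


class DenseMoE:
    pass


@dataclass
class FakeConfig:
    quantize: Optional[str] = None  # e.g. "bitsandbytes", "gptq", or None


def pick_moe_cls(config: FakeConfig):
    # Quantized weights can't back the fused block-sparse kernels,
    # so fall back to DenseMoE whenever quantization is requested.
    return BlockSparseMoE if config.quantize is None else DenseMoE


assert pick_moe_cls(FakeConfig(quantize=None)) is BlockSparseMoE
assert pick_moe_cls(FakeConfig(quantize="bitsandbytes")) is DenseMoE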
@@ -675,8 +824,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -695,7 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
@@ -136,9 +136,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
 
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
-            needed_blocks = min(
-                math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
-            )
+            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+            if SLIDING_WINDOW_BLOCKS is not None:
+                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
             blocks += needed_blocks
 
             needed_blocks_slots.append((needed_blocks, total_tokens))
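Block budgeting treats the sliding-window cap as optional: a request always needs ceil(total_tokens / BLOCK_SIZE) KV-cache blocks, and that count is clamped only when SLIDING_WINDOW_BLOCKS is set. A worked sketch of the arithmetic (the helper name is hypothetical; a BLOCK_SIZE of 16 is assumed here):

import math
from typing import Optional

BLOCK_SIZE = 16  # assumed paged-attention block size


def needed_blocks(total_tokens: int, sliding_window_blocks: Optional[int]) -> int:
    """Hypothetical helper mirroring the hunk above."""
    blocks = math.ceil(total_tokens / BLOCK_SIZE)
    if sliding_window_blocks is not None:
        # With a sliding window only the last `window` tokens stay cached,
        # so the block budget is capped at the window's block count.
        blocks = min(blocks, sliding_window_blocks)
    return blocks


# A 4096-token window maps to 256 blocks of 16 slots each.
assert needed_blocks(10_000, sliding_window_blocks=256) == 256
# No window configured: the full sequence must fit in the cache.
assert needed_blocks(10_000, sliding_window_blocks=None) == 625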
@@ -152,12 +152,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
             slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
-            request_prefill_cache_indices = torch.arange(
-                cumulative_length + max(0, input_length - SLIDING_WINDOW),
-                cumulative_length + input_length,
-                dtype=torch.int64,
-            )
-            prefill_cache_indices.append(request_prefill_cache_indices)
+            if SLIDING_WINDOW is not None:
+                request_prefill_cache_indices = torch.arange(
+                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
+                    cumulative_length + input_length,
+                    dtype=torch.int64,
+                )
+                prefill_cache_indices.append(request_prefill_cache_indices)
 
             all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@@ -209,12 +210,14 @@ class FlashMistralBatch(FlashCausalLMBatch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
-            prefill_cache_indices = torch.cat(prefill_cache_indices)
+            if SLIDING_WINDOW is not None:
+                prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
-            prefill_cache_indices = prefill_cache_indices[0]
+            if SLIDING_WINDOW is not None:
+                prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
@@ -222,7 +225,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
 
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
-        prefill_cache_indices = prefill_cache_indices.to(device)
+        prefill_cache_indices = (
+            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+        )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
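When a window is configured, only the last SLIDING_WINDOW positions of each prompt ever land in the rolling KV cache, so the prefill indices select that suffix inside the packed batch; with no window the tensor is never built and prefill_cache_indices ends up None further down. A toy walk-through of the arange arithmetic (all numbers are illustrative):

import torch

SLIDING_WINDOW = 4    # toy value; a real window is e.g. 4096 tokens
cumulative_length = 10  # tokens already packed from previous requests
input_length = 7        # this request's prompt length

# Indices of this request's tokens that fall inside the sliding window,
# expressed in the flattened (packed) prefill tensor.
request_prefill_cache_indices = torch.arange(
    cumulative_length + max(0, input_length - SLIDING_WINDOW),
    cumulative_length + input_length,
    dtype=torch.int64,
)

# Only the last SLIDING_WINDOW positions (13..16) of this prompt are kept.
assert request_prefill_cache_indices.tolist() == [13, 14, 15, 16]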
@@ -314,8 +319,9 @@ class BaseFlashMistral(FlashCausalLM):
         config.quantize = quantize
 
         # Set context windows
-        SLIDING_WINDOW = config.sliding_window
-        SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+        if config.sliding_window is not None:
+            SLIDING_WINDOW = config.sliding_window
+            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
 
         torch.distributed.barrier(group=self.process_group)
 
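BaseFlashMistral only overwrites the module-level SLIDING_WINDOW / SLIDING_WINDOW_BLOCKS values when the config actually defines a window, so models whose config leaves sliding_window null keep them at None and all the is-not-None checks above fall through. A tiny sketch of that pattern, assuming the usual BLOCK_SIZE of 16 and a hypothetical helper name:

import math
from typing import Optional

BLOCK_SIZE = 16
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None


def set_context_window(sliding_window: Optional[int]) -> None:
    """Hypothetical helper mirroring the conditional in the hunk above."""
    global SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
    if sliding_window is not None:
        SLIDING_WINDOW = sliding_window
        SLIDING_WINDOW_BLOCKS = math.ceil(sliding_window / BLOCK_SIZE)


set_context_window(None)  # e.g. a config with "sliding_window": null
assert SLIDING_WINDOW is None and SLIDING_WINDOW_BLOCKS is None

set_context_window(4096)
assert SLIDING_WINDOW_BLOCKS == 256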
@@ -64,8 +64,6 @@ elif CAN_EXLLAMA:
     except ImportError:
         pass
 
-from typing import Optional
-
 HAS_EETQ = False
 try:
     from EETQ import quant_weights, w8_a16_gemm
@@ -489,9 +487,9 @@ class TensorParallelRowLinear(SuperLayer):
             process_group=weights.process_group,
         )
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
-        if self.process_group.size() > 1:
+        if self.process_group.size() > 1 and reduce:
             torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
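TensorParallelRowLinear.forward gains a reduce flag so callers can defer the tensor-parallel all-reduce: DenseMoE calls self.w2[i](h, reduce=False) for every expert and then issues a single all_reduce on the summed output, one collective per MoE layer instead of one per expert. A simplified sketch of the pattern (a stand-in module, not the real SuperLayer hierarchy):

import torch
import torch.nn as nn


class RowParallelLinear(nn.Module):
    """Simplified stand-in: each rank holds a shard of the input features."""

    def __init__(self, in_features: int, out_features: int, process_group=None):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.process_group = process_group

    def forward(self, x: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        out = self.linear(x)
        # Deferring the reduction lets callers batch several partial results
        # (e.g. all experts of a MoE layer) into a single all_reduce.
        if self.process_group is not None and self.process_group.size() > 1 and reduce:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


layer = RowParallelLinear(16, 8)  # single process: no reduction is issued
partial = layer(torch.randn(4, 16), reduce=False)
assert partial.shape == (4, 8)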