mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00

add ep

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent debf477ba4
commit 3db50ed9d3
@@ -409,6 +409,66 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
             return UnquantizedWeight(w)
 
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
+        shapes = [x.shape for x in w]
+
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)
+
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = [
+                    weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
+                    for p in prefixes
+                ]
+                scale = torch.cat(scale, dim=dim)
+                scale = scale.to(weights.device)
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
+            scale = [
+                weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
+                for p in prefixes
+            ]
+            scale = torch.cat(scale, dim=0).reshape(-1)
+
+            input_scale = [
+                weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
+                for p in prefixes
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
+
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
+
+        return UnquantizedWeight(w)
+
     def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
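Note: `requantize_with_max_scale` comes from TGI's fp8 utilities; conceptually, after `torch.cat`, every concatenated segment is rescaled to one shared scale (the per-segment maximum). A minimal standalone sketch of that idea follows; the helper name and shapes are illustrative, not the project's API, and it requires a PyTorch build with float8 support:

import torch

def requantize_to_max_scale_sketch(w_fp8, scales, logical_widths):
    # Hypothetical sketch: give every row segment of a concatenated fp8
    # weight a single shared scale (the max of the per-segment scales).
    max_scale = scales.max()
    out = torch.empty_like(w_fp8)
    start = 0
    for scale, width in zip(scales, logical_widths):
        end = start + width
        # Dequantize with the segment's own scale, requantize with the max.
        seg = w_fp8[start:end].to(torch.float32) * scale
        out[start:end] = (seg / max_scale).to(torch.float8_e4m3fn)
        start = end
    return out, max_scale

w = torch.randn(6, 4).to(torch.float8_e4m3fn)
scales = torch.tensor([0.5, 2.0])
out, shared = requantize_to_max_scale_sketch(w, scales, logical_widths=[3, 3])
assert shared.item() == 2.0 and out.dtype == torch.float8_e4m3fn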
@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import os
 
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.fp8 import (
@@ -46,6 +47,16 @@ class FP8SparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
+        self.ep_rank = self.rank
+        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+
+        if self.use_ep:
+            n_experts = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.ep_rank * n_experts
+        else:
+            self.ep_offset = 0
 
         (
             self.gate_up_proj,
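With `USE_EXPERT_PARALLEL` enabled (the default here), each rank keeps one contiguous block of experts instead of a shard of every expert. A quick check of the ceil-division arithmetic above, assuming the expert count divides evenly across ranks (a non-divisible count would push the last rank's block past the end):

def ep_partition(n_experts: int, world_size: int, rank: int):
    # Mirrors the constructor: ceil-divide experts across ranks, then
    # offset this rank's block by rank * local_count.
    local = (n_experts + world_size - 1) // world_size
    return local, rank * local

# 64 experts over 4 ranks -> 16 local experts each; offsets 0, 16, 32, 48.
assert [ep_partition(64, 4, r) for r in range(4)] == [
    (16, 0), (16, 16), (16, 32), (16, 48)
]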
@@ -57,6 +68,8 @@ class FP8SparseMoELayer(nn.Module):
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )
 
         self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
@@ -65,6 +78,8 @@ class FP8SparseMoELayer(nn.Module):
                 n_experts=n_experts,
                 name=down_proj_name,
                 weights=weights,
+                use_ep=self.use_ep,
+                ep_offset=self.ep_offset,
             )
         )
         if self.weight_block_size is not None:
@@ -99,8 +114,15 @@ class FP8SparseMoELayer(nn.Module):
         )
         total_num_experts = gating_output.size(-1)
         x_fp8, x_scale = dynamic_quant(x, single_scale=True)
-        moe_n_slice = (total_num_experts + 31) // 32
-        n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
+
+        if self.use_ep:
+            moe_n_slice = 1
+            n_expert_slice = (
+                total_num_experts + self.world_size - 1
+            ) // self.world_size
+        else:
+            moe_n_slice = 1
+            n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
         for i in range(moe_n_slice):
             min_expert = i * n_expert_slice
             max_expert = min((i + 1) * n_expert_slice, total_num_experts)
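A worked example of the slicing above: with EP the single slice covers only this rank's share of the experts, while without EP `moe_n_slice == 1` makes the slice span all of them:

total_num_experts, world_size = 64, 4

# EP path: one slice per forward pass, sized to the rank's expert share.
assert (total_num_experts + world_size - 1) // world_size == 16

# Non-EP path: the single slice spans every expert.
moe_n_slice = 1
assert (total_num_experts + moe_n_slice - 1) // moe_n_slice == 64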
@@ -130,8 +152,8 @@ class FP8SparseMoELayer(nn.Module):
                 d_scale_w3=w2_weight_scale,
                 permuted_weights=True,
                 activation="silu",
-                experts_min=min_expert,
-                experts_max=max_expert - 1,
+                experts_min=min_expert + self.ep_offset,
+                experts_max=max_expert + self.ep_offset - 1,
             )
             htorch.core.mark_step()
             if i == 0:
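The fused MoE op receives global expert ids, so the rank-local slice bounds are shifted by `ep_offset`. For example, on rank 1 of 4 with 64 experts:

ep_offset = 16                    # rank 1 * 16 local experts
min_expert, max_expert = 0, 16    # local slice [0, 16)
experts_min = min_expert + ep_offset       # 16
experts_max = max_expert + ep_offset - 1   # 31, inclusive upper bound
assert (experts_min, experts_max) == (16, 31)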
@@ -148,13 +170,14 @@ def _load_expert_weights(
     n_experts: int,
     name: str,
     weights: Weights,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     all_weight_scales = None
     max_input_scale = None
 
     for i in range(n_experts):
-        weight = get_weight_fn(prefix, i, name, weights)
+        weight = get_weight_fn(prefix, i + ep_offset, name, weights)
 
         assert isinstance(weight, Fp8Weight)
 
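Checkpoint tensors stay keyed by their global expert index, so the loader translates the rank-local loop index before building each tensor name. A sketch assuming the usual `{prefix}.{index}.{name}` layout (the prefix below is illustrative):

prefix, name = "model.layers.0.mlp.experts", "down_proj"
ep_offset, n_local_experts = 16, 4
names = [f"{prefix}.{i + ep_offset}.{name}" for i in range(n_local_experts)]
assert names[0] == "model.layers.0.mlp.experts.16.down_proj"
assert names[-1] == "model.layers.0.mlp.experts.19.down_proj"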
@@ -197,14 +220,26 @@ def _load_expert_multi_weights_col(
     gate_proj_name: str,
     up_proj_name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
-    def get_weight_fn(prefix, i, name, weights):
+    def get_weight_fn_sharded(prefix, i, name, weights):
         return weights.get_multi_weights_col(
             [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
         )
 
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_multi_weights(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
     return _load_expert_weights(
-        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
+        get_weight_fn if use_ep else get_weight_fn_sharded,
+        prefix=prefix,
+        n_experts=n_experts,
+        name=None,
+        weights=weights,
+        ep_offset=ep_offset if use_ep else 0,
     )
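The two inner helpers capture the key difference between the parallelism modes: under tensor parallelism every rank loads a column shard of every expert (`get_multi_weights_col`), while under expert parallelism a rank owns whole experts and loads them unsharded (`get_multi_weights`). A toy contrast with plain tensors:

import torch

full = torch.arange(12.0).reshape(6, 2)  # stand-in for a gate/up weight

# TP-style: rank 0 of 2 keeps only its slice along the output dimension.
tp_shard = torch.chunk(full, 2, dim=0)[0]
assert tp_shard.shape == (3, 2)

# EP-style: the owning rank keeps the whole tensor for its experts.
ep_weight = full
assert ep_weight.shape == (6, 2)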
@@ -214,10 +249,20 @@ def _load_expert_weights_row(
     n_experts: int,
     name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
-    def get_weight_fn(prefix, i, name, weights):
+    def get_weight_fn_sharded(prefix, i, name, weights):
         return weights.get_weights_row(f"{prefix}.{i}.{name}")
 
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_weights(f"{prefix}.{i}.{name}")
+
     return _load_expert_weights(
-        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
+        get_weight_fn if use_ep else get_weight_fn_sharded,
+        prefix=prefix,
+        n_experts=n_experts,
+        name=name,
+        weights=weights,
+        ep_offset=ep_offset if use_ep else 0,
     )
@@ -62,6 +62,14 @@ class WeightsLoader(ABC):
         """
         ...
 
+    @abstractmethod
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        """
+        Get the weights at the given prefixes, column-split them for tensor
+        parallelism, and then concatenate the weights along the given dimension.
+        """
+        ...
+
     @abstractmethod
     def get_weights_row(self, weights: "Weights", prefix: str):
         """
@@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
             weights.get_sharded(f"{prefix}.weight", dim=1),
         )
 
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
+        return self.weight_class(torch.cat(w, dim=dim))
+
 
 class Weights:
     def __init__(
@@ -390,6 +402,9 @@ class Weights:
     def get_weights_row(self, prefix: str):
         return self.weights_loader.get_weights_row(self, prefix)
 
+    def get_multi_weights(self, prefixes: List[str], dim: int):
+        return self.weights_loader.get_multi_weights(self, prefixes, dim)
+
     @contextmanager
     def use_loader(self, weights_loader: WeightsLoader):
         """
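Call sites go through the `Weights` facade, which simply forwards to whichever `WeightsLoader` is active. A self-contained sketch of the delegation pattern with toy stand-ins (the class and tensor names below are illustrative, not the project's API):

from typing import List
import torch

class _ToyLoader:
    # Minimal stand-in mirroring DefaultWeightsLoader.get_multi_weights:
    # fetch each full tensor and concatenate along `dim`.
    def get_multi_weights(self, weights, prefixes: List[str], dim: int):
        return torch.cat(
            [weights.get_tensor(f"{p}.weight") for p in prefixes], dim=dim
        )

class _ToyWeights:
    def __init__(self, tensors, loader):
        self._tensors, self.weights_loader = tensors, loader

    def get_tensor(self, name):
        return self._tensors[name]

    # Facade method added by the diff: forward to the active loader.
    def get_multi_weights(self, prefixes: List[str], dim: int):
        return self.weights_loader.get_multi_weights(self, prefixes, dim)

tensors = {
    "gate_proj.weight": torch.ones(2, 4),
    "up_proj.weight": torch.zeros(2, 4),
}
w = _ToyWeights(tensors, _ToyLoader()).get_multi_weights(
    ["gate_proj", "up_proj"], dim=0
)
assert w.shape == (4, 4)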