remove redundant changes

Mohit Sharma 2025-04-05 15:42:31 +00:00
parent 8094de91fc
commit 04aab711a7
3 changed files with 36 additions and 71 deletions

View File

@@ -286,17 +286,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         return UnquantizedWeight(w)
 
-    def get_multi_weights_col(
-        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
-    ):
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
-        if flag:
-            w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
-                for p in prefixes
-            ]
-        else:
-            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
 
         shapes = [x.shape for x in w]
 
         # Concat then send to the device
@@ -360,13 +354,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
         return UnquantizedWeight(w)
 
-    def get_weights_row(self, weights: "Weights", prefix: str, flag=True):
-        if flag:
-            w = weights.get_sharded(f"{prefix}.weight", dim=1, to_device=False)
-        else:
-            w = weights.get_sharded(f"{prefix}", dim=1, to_device=False)
-        w = w.to(weights.device)
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
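With the flag parameter gone, the loader is back to a single code path: each prefix's .weight column shard is fetched on CPU (to_device=False, per the FIXME above), the shards are concatenated, and only then moved to the device. A minimal sketch of that pattern, where prefixes, fake_shards, and device are made-up stand-ins rather than the real TGI loader calls:

import torch

# Sketch only, not the TGI implementation: fetch per-prefix column shards on CPU,
# record their shapes, concatenate along dim 0, then send the result to the device.
prefixes = ["gate_proj", "up_proj"]
fake_shards = {p: torch.randn(4, 8) for p in prefixes}  # pretend per-prefix shards

w = [fake_shards[p] for p in prefixes]   # stays on CPU, like to_device=False
shapes = [x.shape for x in w]            # recorded before the concat
device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.cat(w, dim=0).to(device)       # concat first, then send to the device
print(shapes, w.shape)                   # [torch.Size([4, 8]), torch.Size([4, 8])] torch.Size([8, 8])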

View File

@@ -5,9 +5,7 @@ import torch.nn as nn
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import Weights
-from text_generation_server.utils.log import log_master
-from loguru import logger
 
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
@@ -115,38 +113,24 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = (
-        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
-        .weight.transpose(2, 1)
-        .contiguous()
-    )
-    # for i in range(n_experts):
-    #     # weight = weights.get_weights_col(
-    #     #     f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
-    #     # )
-    #     # weight = weights.get_multi_weights_col(
-    #     #     [f"{prefix}.{gate_proj_name}", f"{prefix}.{up_proj_name}"], 0
-    #     # )
-    #     weight = weights.get_multi_weights_col(
-    #         [f"{prefix}.gate_up_proj"], 0, flag=False
-    #     )
-    #     # from pdb import set_trace; set_trace()
-    #     assert isinstance(weight, UnquantizedWeight)
-
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
-
-    #     all_weight[i] = weight.weight
-
-    # assert all_weight is not None
-    log_master(logger.info, f"w1: {all_weight.shape}")
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
 
     return all_weight
@@ -158,29 +142,23 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = (
-        weights.get_weights_row(f"{prefix}.{name}", flag=False)
-        .weight.transpose(1, 2)
-        .contiguous()
-    )
-    # for i in range(n_experts):
-    #     weight = weights.get_weights_row(
-    #         f"{prefix}.{name}", flag=False
-    #     )
-
-    #     assert isinstance(weight, UnquantizedWeight)
-
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
-
-    #     all_weight[i] = weight.weight
-
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
     assert all_weight is not None
-    log_master(logger.info, f"w2: {all_weight.shape}")
 
     return all_weight
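With the flag=False shortcuts, the debug logging, and the commented-out loops removed, expert weights are again loaded one expert at a time and stacked into a single (n_experts, ...) buffer that is allocated lazily from the first expert's shape. A minimal sketch of that stacking pattern, where load_one_expert is a hypothetical stand-in for weights.get_multi_weights_col / weights.get_weights_row and the shapes are made up:

import torch

# Sketch of the restored per-expert loading loop: allocate the buffer from the
# first expert's shape/dtype/device, then copy each expert's weight into its slot.
n_experts = 4

def load_one_expert(i: int) -> torch.Tensor:
    return torch.full((16, 8), float(i))  # pretend this is expert i's projection weight

all_weight = None
for i in range(n_experts):
    weight = load_one_expert(i)
    if all_weight is None:
        all_weight = torch.empty(
            (n_experts,) + weight.shape, dtype=weight.dtype, device=weight.device
        )
    all_weight[i] = weight

assert all_weight is not None
print(all_weight.shape)  # torch.Size([4, 16, 8])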

View File

@@ -250,8 +250,6 @@ class Weights:
             tensor = slice_[start:stop]
         elif dim == 1:
             tensor = slice_[:, start:stop]
-        elif dim == 2:
-            tensor = slice_[:, :, start:stop]
         else:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
@@ -375,8 +373,8 @@ class Weights:
     def get_weights_col(self, prefix: str):
         return self.weights_loader.get_weights_col(self, prefix)
 
-    def get_multi_weights_col(self, prefixes: List[str], dim: int, flag=True):
-        return self.weights_loader.get_multi_weights_col(self, prefixes, dim, flag=flag)
+    def get_multi_weights_col(self, prefixes: List[str], dim: int):
+        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
 
     def get_tensor_shard(self, var, dim):
         world_size = self.process_group.size()
@@ -394,8 +392,8 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_weights_row(self, prefix: str, flag=True):
-        return self.weights_loader.get_weights_row(self, prefix, flag=flag)
+    def get_weights_row(self, prefix: str):
+        return self.weights_loader.get_weights_row(self, prefix)
 
     @contextmanager
     def use_loader(self, weights_loader: WeightsLoader):
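Dropping the dim == 2 branch leaves the shard slicing with only the two axes the loaders actually use: dim 0 for column-parallel weights and dim 1 for row-parallel weights. An illustrative helper showing the same slicing, with rank and world_size as made-up stand-ins for the process-group values (not the TGI API):

import torch

# Sketch of the dim handling kept in Weights after this commit: slice a full
# tensor into this rank's shard along dim 0 or dim 1; anything else raises.
def shard(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    block = tensor.shape[dim] // world_size
    start, stop = rank * block, (rank + 1) * block
    if dim == 0:
        return tensor[start:stop]
    elif dim == 1:
        return tensor[:, start:stop]
    else:
        raise NotImplementedError("Let's make that generic when needed")

full = torch.randn(8, 6)
print(shard(full, dim=0, rank=0, world_size=2).shape)  # torch.Size([4, 6]), column shard
print(shard(full, dim=1, rank=1, world_size=2).shape)  # torch.Size([8, 3]), row shard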