remove redundant changes

This commit is contained in:
Mohit Sharma 2025-04-05 15:42:31 +00:00
parent 8094de91fc
commit 04aab711a7
3 changed files with 36 additions and 71 deletions

View File

@ -286,17 +286,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
return UnquantizedWeight(w)
def get_multi_weights_col(
self, weights: "Weights", prefixes: List[str], dim: int, flag=True
):
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
if flag:
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
for p in prefixes
]
else:
w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
@ -360,13 +354,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
return UnquantizedWeight(w)
def get_weights_row(self, weights: "Weights", prefix: str, flag=True):
if flag:
w = weights.get_sharded(f"{prefix}.weight", dim=1, to_device=False)
else:
w = weights.get_sharded(f"{prefix}", dim=1, to_device=False)
w = w.to(weights.device)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:

View File

@ -5,9 +5,7 @@ import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.log import log_master
from loguru import logger
from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
@ -115,38 +113,24 @@ def _load_expert_multi_weights_col(
weights: Weights,
) -> torch.Tensor:
all_weight = None
all_weight = (
weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
.weight.transpose(2, 1)
.contiguous()
)
# for i in range(n_experts):
# # weight = weights.get_weights_col(
# # f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
# # )
# # weight = weights.get_multi_weights_col(
# # [f"{prefix}.{gate_proj_name}", f"{prefix}.{up_proj_name}"], 0
# # )
for i in range(n_experts):
weight = weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
# weight = weights.get_multi_weights_col(
# [f"{prefix}.gate_up_proj"], 0, flag=False
# )
assert isinstance(weight, UnquantizedWeight)
# from pdb import set_trace; set_trace()
# assert isinstance(weight, UnquantizedWeight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=weight.weight.dtype,
device=weight.weight.device,
)
# if all_weight is None:
# all_weight = torch.empty(
# (n_experts,) + weight.weight.shape,
# dtype=weight.weight.dtype,
# device=weight.weight.device,
# )
all_weight[i] = weight.weight
# all_weight[i] = weight.weight
assert all_weight is not None
# assert all_weight is not None
log_master(logger.info, f"w1: {all_weight.shape}")
return all_weight
@ -158,29 +142,23 @@ def _load_expert_weights_row(
weights: Weights,
) -> torch.Tensor:
all_weight = None
all_weight = (
weights.get_weights_row(f"{prefix}.{name}", flag=False)
.weight.transpose(1, 2)
.contiguous()
)
# for i in range(n_experts):
# weight = weights.get_weights_row(
# f"{prefix}.{name}", flag=False
# )
for i in range(n_experts):
weight = weights.get_weights_row(
f"{prefix}.{i}.{name}",
)
# assert isinstance(weight, UnquantizedWeight)
assert isinstance(weight, UnquantizedWeight)
# if all_weight is None:
# all_weight = torch.empty(
# (n_experts,) + weight.weight.shape,
# dtype=weight.weight.dtype,
# device=weight.weight.device,
# )
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=weight.weight.dtype,
device=weight.weight.device,
)
# all_weight[i] = weight.weight
all_weight[i] = weight.weight
assert all_weight is not None
log_master(logger.info, f"w2: {all_weight.shape}")
return all_weight

View File

@ -250,8 +250,6 @@ class Weights:
tensor = slice_[start:stop]
elif dim == 1:
tensor = slice_[:, start:stop]
elif dim == 2:
tensor = slice_[:, :, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
@ -375,8 +373,8 @@ class Weights:
def get_weights_col(self, prefix: str):
return self.weights_loader.get_weights_col(self, prefix)
def get_multi_weights_col(self, prefixes: List[str], dim: int, flag=True):
return self.weights_loader.get_multi_weights_col(self, prefixes, dim, flag=flag)
def get_multi_weights_col(self, prefixes: List[str], dim: int):
return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
def get_tensor_shard(self, var, dim):
world_size = self.process_group.size()
@ -394,8 +392,8 @@ class Weights:
tensor = tensor.to(device=self.device)
return tensor
def get_weights_row(self, prefix: str, flag=True):
return self.weights_loader.get_weights_row(self, prefix, flag=flag)
def get_weights_row(self, prefix: str):
return self.weights_loader.get_weights_row(self, prefix)
@contextmanager
def use_loader(self, weights_loader: WeightsLoader):