From 04aab711a7f028ef473de9407b5da2953240ad5a Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Sat, 5 Apr 2025 15:42:31 +0000
Subject: [PATCH] remove redundant changes

---
 server/text_generation_server/layers/fp8.py  | 23 ++----
 .../layers/moe/unquantized.py                | 74 +++++++------
 .../text_generation_server/utils/weights.py  | 10 +--
 3 files changed, 36 insertions(+), 71 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 4043aea91..04689ed92 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -286,17 +286,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         return UnquantizedWeight(w)
 
-    def get_multi_weights_col(
-        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
-    ):
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
-        if flag:
-            w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
-                for p in prefixes
-            ]
-        else:
-            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
 
         shapes = [x.shape for x in w]
 
         # Concat then send to the device
@@ -360,13 +354,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         return UnquantizedWeight(w)
 
-    def get_weights_row(self, weights: "Weights", prefix: str, flag=True):
-        if flag:
-            w = weights.get_sharded(f"{prefix}.weight", dim=1, to_device=False)
-        else:
-            w = weights.get_sharded(f"{prefix}", dim=1, to_device=False)
-
-        w = w.to(weights.device)
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index a6ef467df..007f99d05 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -5,9 +5,7 @@ import torch.nn as nn
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import Weights
-from text_generation_server.utils.log import log_master
-from loguru import logger
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
@@ -115,38 +113,24 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = (
-        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
-        .weight.transpose(2, 1)
-        .contiguous()
-    )
-    # for i in range(n_experts):
-    #     # weight = weights.get_weights_col(
-    #     #     f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
-    #     # )
-    #     # weight = weights.get_multi_weights_col(
-    #     #     [f"{prefix}.{gate_proj_name}", f"{prefix}.{up_proj_name}"], 0
-    #     # )
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
 
-    #     weight = weights.get_multi_weights_col(
-    #         [f"{prefix}.gate_up_proj"], 0, flag=False
-    #     )
+        assert isinstance(weight, UnquantizedWeight)
 
-    #     from pdb import set_trace; set_trace()
-    #     assert isinstance(weight, UnquantizedWeight)
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
 
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
+        all_weight[i] = weight.weight
 
-    #     all_weight[i] = weight.weight
+    assert all_weight is not None
 
-    #     assert all_weight is not None
-
-    log_master(logger.info, f"w1: {all_weight.shape}")
     return all_weight
 
 
@@ -158,29 +142,23 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = (
-        weights.get_weights_row(f"{prefix}.{name}", flag=False)
-        .weight.transpose(1, 2)
-        .contiguous()
-    )
-    # for i in range(n_experts):
-    #     weight = weights.get_weights_row(
-    #         f"{prefix}.{name}", flag=False
-    #     )
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
 
-    #     assert isinstance(weight, UnquantizedWeight)
+        assert isinstance(weight, UnquantizedWeight)
 
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
 
-    #     all_weight[i] = weight.weight
+        all_weight[i] = weight.weight
 
     assert all_weight is not None
 
-    log_master(logger.info, f"w2: {all_weight.shape}")
     return all_weight
 
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index ef4117167..c03dd2b0d 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -250,8 +250,6 @@ class Weights:
             tensor = slice_[start:stop]
         elif dim == 1:
             tensor = slice_[:, start:stop]
-        elif dim == 2:
-            tensor = slice_[:, :, start:stop]
         else:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
@@ -375,8 +373,8 @@ class Weights:
     def get_weights_col(self, prefix: str):
         return self.weights_loader.get_weights_col(self, prefix)
 
-    def get_multi_weights_col(self, prefixes: List[str], dim: int, flag=True):
-        return self.weights_loader.get_multi_weights_col(self, prefixes, dim, flag=flag)
+    def get_multi_weights_col(self, prefixes: List[str], dim: int):
+        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
 
     def get_tensor_shard(self, var, dim):
         world_size = self.process_group.size()
@@ -394,8 +392,8 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_weights_row(self, prefix: str, flag=True):
-        return self.weights_loader.get_weights_row(self, prefix, flag=flag)
+    def get_weights_row(self, prefix: str):
+        return self.weights_loader.get_weights_row(self, prefix)
 
     @contextmanager
     def use_loader(self, weights_loader: WeightsLoader):
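
Reviewer note (illustration only, not part of the commit): below is a minimal,
runnable sketch of the per-expert loading pattern this patch restores in
_load_expert_multi_weights_col / _load_expert_weights_row. Each expert's
tensor is fetched under its own "{prefix}.{i}.{name}" key and copied into a
buffer preallocated from the first expert's shape, instead of loading one
fused tensor and transposing it. The FakeWeights class is a hypothetical
stand-in for the repo's Weights loader, used only to keep the sketch
self-contained.

import torch


class FakeWeights:
    """Hypothetical stand-in: maps "{prefix}.{i}.{name}" keys to tensors."""

    def __init__(self, tensors):
        self.tensors = tensors

    def get_weights_row(self, prefix: str) -> torch.Tensor:
        # The real loader returns an UnquantizedWeight wrapping a shard;
        # the tensor is returned directly here for brevity.
        return self.tensors[prefix]


def load_expert_weights_row(weights, prefix, name, n_experts):
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_weights_row(f"{prefix}.{i}.{name}")
        if all_weight is None:
            # Preallocate once, sized from the first expert's shard.
            all_weight = torch.empty(
                (n_experts,) + weight.shape,
                dtype=weight.dtype,
                device=weight.device,
            )
        all_weight[i] = weight
    assert all_weight is not None
    return all_weight


# Usage: three experts, each with a (4, 8) down-projection tensor.
w = FakeWeights({f"mlp.experts.{i}.down_proj": torch.randn(4, 8) for i in range(3)})
stacked = load_expert_weights_row(w, "mlp.experts", "down_proj", n_experts=3)
print(stacked.shape)  # torch.Size([3, 4, 8])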