Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-19 22:02:06 +00:00
remove redundant changes
parent 8094de91fc
commit 04aab711a7
@@ -286,17 +286,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         return UnquantizedWeight(w)
 
-    def get_multi_weights_col(
-        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
-    ):
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
-        if flag:
-            w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
-                for p in prefixes
-            ]
-        else:
-            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
         shapes = [x.shape for x in w]
 
         # Concat then send to the device
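Reviewer note: the kept path always shards `{p}.weight` along dim 0 (output features) and concatenates the shards afterwards; the deleted `flag=False` branch, which sharded a bare `{p}` tensor along dim 2, only served the fused-expert experiment reverted below. A minimal sketch of what dim-0 "column" sharding means, using plain tensors and made-up shapes rather than TGI's `Weights` API:

import torch

# Hypothetical standalone illustration of dim-0 ("column") sharding:
# each tensor-parallel rank keeps a contiguous block of output rows.
def shard_dim0(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    block = weight.shape[0] // world_size
    return weight[rank * block : (rank + 1) * block]

full = torch.randn(8, 4)  # (out_features, in_features)
shards = [shard_dim0(full, r, 2) for r in range(2)]
# Concatenating the per-rank shards along dim 0 recovers the full weight,
# which is why the loader can torch.cat them after moving to the device.
assert torch.equal(torch.cat(shards, dim=0), full)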
@@ -360,13 +354,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         return UnquantizedWeight(w)
 
-    def get_weights_row(self, weights: "Weights", prefix: str, flag=True):
-        if flag:
-            w = weights.get_sharded(f"{prefix}.weight", dim=1, to_device=False)
-        else:
-            w = weights.get_sharded(f"{prefix}", dim=1, to_device=False)
-
-        w = w.to(weights.device)
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
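Note that the restored `get_weights_row` also drops `to_device=False` and the explicit `w.to(weights.device)`, so `get_sharded` is again responsible for device placement. For intuition, a hedged sketch of why dim-1 (input-feature) shards compose under a sum, which is what the tensor-parallel all-reduce computes; plain tensors, not TGI code:

import torch

# Hypothetical sketch of dim-1 ("row") sharding: each rank keeps a slice of
# the input features, and per-rank partial matmuls sum to the full result.
def shard_dim1(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    block = weight.shape[1] // world_size
    return weight[:, rank * block : (rank + 1) * block]

full = torch.randn(4, 8)  # (out_features, in_features)
x = torch.randn(8)
block = full.shape[1] // 2
partial_sum = sum(
    shard_dim1(full, r, 2) @ x[r * block : (r + 1) * block] for r in range(2)
)
# Summing the per-rank partial products matches the unsharded matmul.
assert torch.allclose(partial_sum, full @ x, atol=1e-6)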
@@ -5,9 +5,7 @@ import torch.nn as nn
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import Weights
-from text_generation_server.utils.log import log_master
-from loguru import logger
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
@@ -115,38 +113,24 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
-    all_weight = (
-        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
-        .weight.transpose(2, 1)
-        .contiguous()
-    )
-
-    # for i in range(n_experts):
-    #     # weight = weights.get_weights_col(
-    #     #     f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
-    #     # )
-    #     # weight = weights.get_multi_weights_col(
-    #     #     [f"{prefix}.{gate_proj_name}", f"{prefix}.{up_proj_name}"], 0
-    #     # )
-
-    #     weight = weights.get_multi_weights_col(
-    #         [f"{prefix}.gate_up_proj"], 0, flag=False
-    #     )
-
-    # from pdb import set_trace; set_trace()
-    # assert isinstance(weight, UnquantizedWeight)
-
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
-
-    #     all_weight[i] = weight.weight
-
-    # assert all_weight is not None
-
-    log_master(logger.info, f"w1: {all_weight.shape}")
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
 
     return all_weight
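The restored loop sizes the stacked expert buffer lazily from the first expert's shape. The same pattern in a self-contained toy, with a hypothetical `fake_load` standing in for `weights.get_multi_weights_col(...).weight`:

import torch

# Minimal sketch of the restored per-expert loading pattern: lazily
# allocate a (n_experts, ...) buffer, then fill one expert slot at a time.
n_experts = 4

def fake_load(i: int) -> torch.Tensor:
    # Stand-in for the real loader; every expert shares one shape.
    return torch.full((6, 3), float(i))

all_weight = None
for i in range(n_experts):
    weight = fake_load(i)
    if all_weight is None:
        all_weight = torch.empty(
            (n_experts,) + weight.shape, dtype=weight.dtype, device=weight.device
        )
    all_weight[i] = weight

assert all_weight is not None and all_weight.shape == (4, 6, 3)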
@@ -158,29 +142,23 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
-    all_weight = (
-        weights.get_weights_row(f"{prefix}.{name}", flag=False)
-        .weight.transpose(1, 2)
-        .contiguous()
-    )
-
-    # for i in range(n_experts):
-    #     weight = weights.get_weights_row(
-    #         f"{prefix}.{name}", flag=False
-    #     )
-
-    #     assert isinstance(weight, UnquantizedWeight)
-
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
-
-    #     all_weight[i] = weight.weight
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
 
     assert all_weight is not None
-    log_master(logger.info, f"w2: {all_weight.shape}")
 
     return all_weight
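A guess at what the deleted branches were compensating for, stated as an assumption rather than something this diff confirms: a fused expert checkpoint that stores weights as (n_experts, in_features, out_features) must be transposed to the (n_experts, out_features, in_features) layout the stacked-expert code expects, hence the `.transpose(...).contiguous()` calls. A toy illustration with invented shapes:

import torch

# Invented shapes for illustration only.
fused = torch.randn(4, 3, 6)            # (n_experts, in, out), hypothetical
stacked = fused.transpose(1, 2).contiguous()
assert stacked.shape == (4, 6, 3)       # (n_experts, out, in)
# The per-expert loop above produces the (n_experts, out, in) layout
# directly, which is why the transpose (and the flag= plumbing) was dropped.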
@@ -250,8 +250,6 @@ class Weights:
             tensor = slice_[start:stop]
         elif dim == 1:
             tensor = slice_[:, start:stop]
-        elif dim == 2:
-            tensor = slice_[:, :, start:stop]
         else:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
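The removed `dim == 2` case existed only for the fused-expert loads reverted above, so `get_sharded` is back to supporting dims 0 and 1. If the `NotImplementedError` branch ever needed the generic answer it hints at, one hedged sketch (not TGI code) is `torch.narrow`:

import torch

# Hypothetical generic slice along an arbitrary dim, equivalent to the
# hard-coded slice_[start:stop] / slice_[:, start:stop] branches kept here.
def shard_any_dim(t: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
    return t.narrow(dim, start, stop - start)

t = torch.arange(24).reshape(2, 3, 4)
assert torch.equal(shard_any_dim(t, 1, 0, 2), t[:, 0:2])
assert torch.equal(shard_any_dim(t, 2, 1, 3), t[:, :, 1:3])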
@@ -375,8 +373,8 @@ class Weights:
     def get_weights_col(self, prefix: str):
         return self.weights_loader.get_weights_col(self, prefix)
 
-    def get_multi_weights_col(self, prefixes: List[str], dim: int, flag=True):
-        return self.weights_loader.get_multi_weights_col(self, prefixes, dim, flag=flag)
+    def get_multi_weights_col(self, prefixes: List[str], dim: int):
+        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
 
     def get_tensor_shard(self, var, dim):
         world_size = self.process_group.size()
@@ -394,8 +392,8 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_weights_row(self, prefix: str, flag=True):
-        return self.weights_loader.get_weights_row(self, prefix, flag=flag)
+    def get_weights_row(self, prefix: str):
+        return self.weights_loader.get_weights_row(self, prefix)
 
     @contextmanager
     def use_loader(self, weights_loader: WeightsLoader):
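The `get_multi_weights_col` and `get_weights_row` hunks above restore pure delegation: `Weights` keeps a stable public signature and forwards to the pluggable loader, so quantization-specific behavior (fp8, gptq, ...) stays inside the loader instead of leaking parameters like `flag` into every call site. A toy model of the pattern; `Loader`, `DefaultLoader`, and `ToyWeights` are invented names, not TGI classes:

from typing import Protocol

class Loader(Protocol):
    def get_weights_row(self, weights: "ToyWeights", prefix: str) -> str: ...

class DefaultLoader:
    def get_weights_row(self, weights: "ToyWeights", prefix: str) -> str:
        return f"{prefix}.weight"

class ToyWeights:
    def __init__(self, weights_loader: Loader):
        self.weights_loader = weights_loader

    def get_weights_row(self, prefix: str) -> str:
        # Pure delegation, mirroring the restored one-line body above.
        return self.weights_loader.get_weights_row(self, prefix)

assert ToyWeights(DefaultLoader()).get_weights_row("model.q_proj") == "model.q_proj.weight"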