Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-19 22:02:06 +00:00
remove redundant changes

parent 8094de91fc
commit 04aab711a7
@@ -286,17 +286,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         return UnquantizedWeight(w)
 
-    def get_multi_weights_col(
-        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
-    ):
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
-        if flag:
-            w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
-                for p in prefixes
-            ]
-        else:
-            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
         shapes = [x.shape for x in w]
 
         # Concat then send to the device
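For context, a minimal self-contained sketch (not TGI's actual classes; multi_weights_col_sketch and the shard shapes are made up for illustration) of the pattern the restored method implements: gather one 2-D shard per prefix, concatenate along the output dimension (dim 0) on the host, then move the result, since fp8 tensors do not yet support torch.cat on device:

import torch

def multi_weights_col_sketch(shards, device="cpu"):
    # Concatenate per-prefix shards along dim 0 on the host first, then move;
    # this mirrors the to_device=False + "concat then send" comment above.
    w = torch.cat(shards, dim=0)
    return w.to(device)

# Hypothetical q/k/v shards: 8 output rows x 16 input columns each.
shards = [torch.randn(8, 16) for _ in ("q_proj", "k_proj", "v_proj")]
print(multi_weights_col_sketch(shards).shape)  # torch.Size([24, 16])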
@@ -360,13 +354,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         return UnquantizedWeight(w)
 
-    def get_weights_row(self, weights: "Weights", prefix: str, flag=True):
-        if flag:
-            w = weights.get_sharded(f"{prefix}.weight", dim=1, to_device=False)
-        else:
-            w = weights.get_sharded(f"{prefix}", dim=1, to_device=False)
-
-        w = w.to(weights.device)
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             if self.weight_block_size is not None:
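The row counterpart shards the input dimension instead of the output dimension. A hedged sketch, with get_sharded_row_sketch, rank, and world_size as illustrative stand-ins for TGI's sharding machinery:

import torch

def get_sharded_row_sketch(weight, rank, world_size):
    # Row-parallel layers split the input dimension (dim=1): each rank keeps
    # one contiguous block of columns of the full [out, in] matrix.
    block = weight.shape[1] // world_size
    return weight[:, rank * block : (rank + 1) * block]

full = torch.randn(4, 12)
print(get_sharded_row_sketch(full, rank=1, world_size=3).shape)  # torch.Size([4, 4])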
@@ -5,9 +5,7 @@ import torch.nn as nn
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import Weights
-from text_generation_server.utils.log import log_master
-from loguru import logger
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
@@ -115,38 +113,24 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = (
-        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
-        .weight.transpose(2, 1)
-        .contiguous()
-    )
-    # for i in range(n_experts):
-    # #     weight = weights.get_weights_col(
-    # #         f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
-    # #     )
-    # #     weight = weights.get_multi_weights_col(
-    # #         [f"{prefix}.{gate_proj_name}", f"{prefix}.{up_proj_name}"], 0
-    # #     )
-
-    #     weight = weights.get_multi_weights_col(
-    #         [f"{prefix}.gate_up_proj"], 0, flag=False
-    #     )
-
-    # from pdb import set_trace; set_trace()
-    #     assert isinstance(weight, UnquantizedWeight)
-
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
-
-    #         all_weight[i] = weight.weight
-
-    #     assert all_weight is not None
-
-    log_master(logger.info, f"w1: {all_weight.shape}")
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
+
     return all_weight
 
 
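The restored loop is a lazy-allocation stacking pattern: size the (n_experts, ...) buffer from the first expert's weight, then copy each expert in. A standalone sketch of just that pattern (shapes and expert count are hypothetical):

import torch

def stack_expert_weights_sketch(expert_weights):
    all_weight = None
    for i, w in enumerate(expert_weights):
        if all_weight is None:
            # Allocate once, using the first expert's shape, dtype, and device.
            all_weight = torch.empty(
                (len(expert_weights),) + w.shape, dtype=w.dtype, device=w.device
            )
        all_weight[i] = w
    assert all_weight is not None
    return all_weight

experts = [torch.randn(6, 4) for _ in range(8)]  # e.g. 8 experts
print(stack_expert_weights_sketch(experts).shape)  # torch.Size([8, 6, 4])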
@@ -158,29 +142,23 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = (
-        weights.get_weights_row(f"{prefix}.{name}", flag=False)
-        .weight.transpose(1, 2)
-        .contiguous()
-    )
-    # for i in range(n_experts):
-    #     weight = weights.get_weights_row(
-    #         f"{prefix}.{name}", flag=False
-    #     )
-
-    #     assert isinstance(weight, UnquantizedWeight)
-
-    #     if all_weight is None:
-    #         all_weight = torch.empty(
-    #             (n_experts,) + weight.weight.shape,
-    #             dtype=weight.weight.dtype,
-    #             device=weight.weight.device,
-    #         )
-
-    #         all_weight[i] = weight.weight
-
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
     assert all_weight is not None
-    log_master(logger.info, f"w2: {all_weight.shape}")
 
     return all_weight
 
@@ -250,8 +250,6 @@ class Weights:
             tensor = slice_[start:stop]
         elif dim == 1:
             tensor = slice_[:, start:stop]
-        elif dim == 2:
-            tensor = slice_[:, :, start:stop]
         else:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
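The removed dim == 2 branch was only needed to slice the fused 3-D expert tensors that the flag=False path loaded; with per-expert loading restored, sharding is back to 2-D. A sketch of the dispatch that remains (shard_slice_sketch is illustrative, not the real method):

import torch

def shard_slice_sketch(slice_, dim, start, stop):
    if dim == 0:
        return slice_[start:stop]
    elif dim == 1:
        return slice_[:, start:stop]
    else:
        raise NotImplementedError("Let's make that generic when needed")

x = torch.arange(24).reshape(4, 6)
print(shard_slice_sketch(x, dim=1, start=2, stop=4).shape)  # torch.Size([4, 2])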
@@ -375,8 +373,8 @@ class Weights:
     def get_weights_col(self, prefix: str):
         return self.weights_loader.get_weights_col(self, prefix)
 
-    def get_multi_weights_col(self, prefixes: List[str], dim: int, flag=True):
-        return self.weights_loader.get_multi_weights_col(self, prefixes, dim, flag=flag)
+    def get_multi_weights_col(self, prefixes: List[str], dim: int):
+        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
 
     def get_tensor_shard(self, var, dim):
         world_size = self.process_group.size()
@@ -394,8 +392,8 @@ class Weights:
             tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_weights_row(self, prefix: str, flag=True):
-        return self.weights_loader.get_weights_row(self, prefix, flag=flag)
+    def get_weights_row(self, prefix: str):
+        return self.weights_loader.get_weights_row(self, prefix)
 
     @contextmanager
     def use_loader(self, weights_loader: WeightsLoader):
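Both wrappers are thin delegations to the pluggable loader; dropping flag narrows the public surface back to the upstream signatures. A runnable toy version of that delegation (SketchLoader and SketchWeights are stand-ins, not TGI classes):

from typing import List

class SketchLoader:
    # Stand-in for a WeightsLoader implementation.
    def get_multi_weights_col(self, weights, prefixes: List[str], dim: int):
        return f"col({prefixes}, dim={dim})"

    def get_weights_row(self, weights, prefix: str):
        return f"row({prefix})"

class SketchWeights:
    # Stand-in for Weights: it only forwards to its loader.
    def __init__(self, loader):
        self.weights_loader = loader

    def get_multi_weights_col(self, prefixes: List[str], dim: int):
        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)

    def get_weights_row(self, prefix: str):
        return self.weights_loader.get_weights_row(self, prefix)

w = SketchWeights(SketchLoader())
print(w.get_multi_weights_col(["q_proj", "k_proj"], dim=0))
print(w.get_weights_row("down_proj"))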