mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-02 22:10:17 +00:00
Re-enabling dim=dim in TensorParallelColumn because llama.
This commit is contained in:
parent
ae308f88ec
commit
3fb8979a6d
@ -221,11 +221,11 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||||
weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize)
|
weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim)
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||||
bias = torch.cat(b, dim=0)
|
bias = torch.cat(b, dim=dim)
|
||||||
else:
|
else:
|
||||||
bias = None
|
bias = None
|
||||||
linear = get_linear(weight, bias, config.quantize)
|
linear = get_linear(weight, bias, config.quantize)
|
||||||
|
@ -83,7 +83,7 @@ class Weights:
|
|||||||
tensor = tensor.to(device=self.device)
|
tensor = tensor.to(device=self.device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str):
|
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||||
if quantize == "gptq":
|
if quantize == "gptq":
|
||||||
try:
|
try:
|
||||||
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
|
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
|
||||||
@ -102,7 +102,7 @@ class Weights:
|
|||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||||
else:
|
else:
|
||||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||||
weight = torch.cat(w, dim=1)
|
weight = torch.cat(w, dim=dim)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||||
|
Loading…
Reference in New Issue
Block a user