Re-enabling dim=dim in TensorParallelColumn because llama.

Ubuntu 2023-06-13 15:37:52 +00:00 committed by Nicolas Patry
parent ae308f88ec
commit 3fb8979a6d
2 changed files with 4 additions and 4 deletions

@@ -221,11 +221,11 @@ class TensorParallelColumnLinear(SuperLayer):
     @classmethod
     def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
-        weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize)
+        weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim)
         if bias:
             b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
-            bias = torch.cat(b, dim=0)
+            bias = torch.cat(b, dim=dim)
         else:
             bias = None
         linear = get_linear(weight, bias, config.quantize)

@@ -83,7 +83,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_multi_weights_col(self, prefixes: List[str], quantize: str):
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        if quantize == "gptq":
            try:
                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
@@ -102,7 +102,7 @@ class Weights:
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-            weight = torch.cat(w, dim=1)
+            weight = torch.cat(w, dim=dim)
        return weight

    def get_multi_weights_row(self, prefix: str, quantize: str):
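
For context on why the dim argument matters, here is a minimal standalone sketch (not taken from the repository; the tensor shapes and prefix names are illustrative assumptions) of what torch.cat does with the per-prefix shards in the non-quantized path:

    # Standalone illustration (assumed shapes, not repository code): column-parallel
    # Linear weights are stored as [out_features, in_features], so fusing several
    # projections such as llama's gate/up (or q/k/v) means concatenating along dim=0.
    import torch

    hidden_size = 8
    w_gate = torch.randn(4, hidden_size)  # hypothetical shard for one prefix
    w_up = torch.randn(4, hidden_size)    # hypothetical shard for another prefix

    fused = torch.cat([w_gate, w_up], dim=0)
    print(fused.shape)  # torch.Size([8, 8]): one fused column-parallel weight

    # Concatenating along dim=1 would instead join the input dimension, which is
    # only correct for other layouts; hence load_multi now threads dim through
    # to get_multi_weights_col instead of hard-coding the concatenation axis.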