mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 12:32:10 +00:00
Re-enabling dim=dim in TensorParallelColumn because llama.
This commit is contained in:
parent
ae308f88ec
commit
3fb8979a6d
@ -221,11 +221,11 @@ class TensorParallelColumnLinear(SuperLayer):
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize)
|
||||
weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim)
|
||||
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=0)
|
||||
bias = torch.cat(b, dim=dim)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
|
@ -83,7 +83,7 @@ class Weights:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str):
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
|
||||
@ -102,7 +102,7 @@ class Weights:
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=1)
|
||||
weight = torch.cat(w, dim=dim)
|
||||
return weight
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
|
Loading…
Reference in New Issue
Block a user