fix gpt2 tests - some weights were not contiguous

Felix Marty 2024-06-13 08:09:52 +00:00 committed by Nicolas Patry
parent 9e50c117bc
commit 5b6b257756


@@ -106,7 +106,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
     for i in range(3):
         tensor = slice_[:, start + i * single_size : stop + i * single_size]
         tensors.append(tensor)
-    weight = torch.cat(tensors, dim=1).T
+    weight = torch.cat(tensors, dim=1).T.contiguous()
     weight = weight.to(dtype=weights.dtype)
     weight = weight.to(device=weights.device)
@@ -139,7 +139,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.quantize == "gptq":
         weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T.contiguous()
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
@@ -159,7 +159,7 @@ def load_col(config, prefix: str, weights, bias: bool):
             [prefix], quantize=config.quantize, dim=1
         )
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T.contiguous()
     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)