Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
fix gpt2 tests - some weights were not contiguous
parent: 9e50c117bc
commit: 5b6b257756
@@ -106,7 +106,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
     for i in range(3):
         tensor = slice_[:, start + i * single_size : stop + i * single_size]
         tensors.append(tensor)
-    weight = torch.cat(tensors, dim=1).T
+    weight = torch.cat(tensors, dim=1).T.contiguous()
     weight = weight.to(dtype=weights.dtype)
     weight = weight.to(device=weights.device)
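In PyTorch, `.T` (like `transpose` generally) returns a view that shares storage with the original tensor and merely swaps its strides, so the result of `torch.cat(tensors, dim=1).T` is not laid out contiguously in memory; chaining `.contiguous()` copies the data into row-major order. A minimal sketch of the behavior (shapes are illustrative, not taken from the GPT-2 loader):

import torch

# Three stand-in weight slices, mimicking the Q, K and V shards
# that _load_qkv concatenates; the shapes are made up for the demo.
tensors = [torch.randn(768, 256) for _ in range(3)]

weight = torch.cat(tensors, dim=1).T
print(weight.is_contiguous())  # False: .T only swaps strides, no copy

weight = torch.cat(tensors, dim=1).T.contiguous()
print(weight.is_contiguous())  # True: data materialized in row-major layout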
@@ -139,7 +139,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.quantize == "gptq":
         weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T.contiguous()

     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
@@ -159,7 +159,7 @@ def load_col(config, prefix: str, weights, bias: bool):
             [prefix], quantize=config.quantize, dim=1
         )
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T.contiguous()

     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
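A non-contiguous weight is not an error by itself, but operations that reinterpret raw storage reject it, which is one common way the problem surfaces as a test failure; `.view()` is the classic example. A hedged sketch of that failure mode, using generic tensors rather than TGI's actual loaders:

import torch

weight = torch.randn(4, 8).T  # non-contiguous view: strides swapped, no copy

try:
    weight.view(-1)  # reinterpreting storage requires a contiguous layout
except RuntimeError as err:
    print(f"view failed: {err}")

flat = weight.contiguous().view(-1)  # succeeds after materializing a copy
print(flat.shape)  # torch.Size([32])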