mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
fix gpt2 tests - some weights were not contiguous
This commit is contained in:
parent
9e50c117bc
commit
5b6b257756
@ -106,7 +106,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
|
|||||||
for i in range(3):
|
for i in range(3):
|
||||||
tensor = slice_[:, start + i * single_size : stop + i * single_size]
|
tensor = slice_[:, start + i * single_size : stop + i * single_size]
|
||||||
tensors.append(tensor)
|
tensors.append(tensor)
|
||||||
weight = torch.cat(tensors, dim=1).T
|
weight = torch.cat(tensors, dim=1).T.contiguous()
|
||||||
weight = weight.to(dtype=weights.dtype)
|
weight = weight.to(dtype=weights.dtype)
|
||||||
weight = weight.to(device=weights.device)
|
weight = weight.to(device=weights.device)
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||||
else:
|
else:
|
||||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T.contiguous()
|
||||||
|
|
||||||
if bias and weights.process_group.rank() == 0:
|
if bias and weights.process_group.rank() == 0:
|
||||||
# Rank is only on the first rank process
|
# Rank is only on the first rank process
|
||||||
@ -159,7 +159,7 @@ def load_col(config, prefix: str, weights, bias: bool):
|
|||||||
[prefix], quantize=config.quantize, dim=1
|
[prefix], quantize=config.quantize, dim=1
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T.contiguous()
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user