diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index 0c01f56a..f2f0197f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -106,7 +106,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
     for i in range(3):
         tensor = slice_[:, start + i * single_size : stop + i * single_size]
         tensors.append(tensor)
-    weight = torch.cat(tensors, dim=1).T
+    weight = torch.cat(tensors, dim=1).T.contiguous()
     weight = weight.to(dtype=weights.dtype)
     weight = weight.to(device=weights.device)
 
@@ -139,7 +139,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.quantize == "gptq":
         weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T.contiguous()
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
@@ -159,7 +159,7 @@ def load_col(config, prefix: str, weights, bias: bool):
             [prefix], quantize=config.quantize, dim=1
         )
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T.contiguous()
 
     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
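
For context, a minimal standalone sketch (not part of the patch, assuming only stock PyTorch) of why the trailing `.contiguous()` matters here: `Tensor.T` returns a stride-swapped view rather than a new tensor, so downstream code that expects a contiguous weight (e.g. `.view()` calls or kernels that assume row-major layout) can fail or trigger implicit copies unless the transposed result is materialized.

```python
import torch

# Illustrative example only (not from the patch): .T swaps strides and yields a
# non-contiguous view; .contiguous() copies the data into row-major layout.
w = torch.randn(4, 8)

print(w.T.is_contiguous())               # False: transposed view, strides swapped
print(w.T.contiguous().is_contiguous())  # True: materialized contiguous copy

# view() requires compatible strides, so it fails on the transposed view
# but works on the contiguous copy.
try:
    w.T.view(-1)
except RuntimeError as e:
    print("view() on non-contiguous tensor raises:", type(e).__name__)
print(w.T.contiguous().view(-1).shape)    # torch.Size([32])
```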