diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 63dbedb77..9c1020a5f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -58,7 +58,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): - weight = weights.get_multi_weights_col([prefix], quantize=config.quantize) + weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) if isinstance(weight, torch.Tensor): # Only on non quantized versions weight = (