From 7df81c34dba8d84b1194ffc6a1c6659fb6e93185 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 31 Jan 2023 18:34:47 +0100 Subject: [PATCH] patch quantization --- server/text_generation/models/gpt_neox.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py index 0d38e72e..d901cae3 100644 --- a/server/text_generation/models/gpt_neox.py +++ b/server/text_generation/models/gpt_neox.py @@ -35,6 +35,8 @@ class GPTNeox(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + """Overwrite forward to ignore position_ids""" + # Model Forward outputs = self.model.forward( input_ids=input_ids, @@ -174,9 +176,9 @@ class GPTNeoxSharded(GPTNeox): ) if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" ): tensor = Int8Params( tensor, @@ -193,15 +195,11 @@ class GPTNeoxSharded(GPTNeox): tensor.CB = None tensor.SCB = None - def replace_linear(state, in_features, out_features): + def replace_linear(state): def linear(input, weight, bias): - size_out = input.size()[:-1] + (out_features,) - input = input.view(-1, in_features) - out = input.new_empty(size_out) out = bnb.matmul( input, weight, - out=out.view(-1, out_features), state=state, threshold=state.threshold, bias=bias, @@ -214,13 +212,11 @@ class GPTNeoxSharded(GPTNeox): del state.CB weight.data = state.CxB - return out.view(size_out) + return out return linear - module.linear = replace_linear( - state, module.in_features, module.out_features - ) + module.linear = replace_linear(state) else: tensor = tensor.to(device)