patch quantization

OlivierDehaene 2023-01-31 18:34:47 +01:00
parent ffccb7f9ce
commit 7df81c34db

@@ -35,6 +35,8 @@ class GPTNeox(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        """Overwrite forward to ignore position_ids"""
+
         # Model Forward
         outputs = self.model.forward(
             input_ids=input_ids,
@@ -174,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
                         )

                         if (
                             type(module)
                             in [TensorParallelRowLinear, TensorParallelColumnLinear]
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor,
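For readers without the full file: the Int8Params construction that this hunk truncates usually follows the standard bitsandbytes int8 recipe, sketched below. Only `Int8Params(tensor, ...)`, `state.threshold`, and the `tensor.CB = None` / `tensor.SCB = None` lines are visible in the hunks here; the keyword arguments and the rest of the `MatmulLtState` setup are assumptions, flagged in the comments.

import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params

device = torch.device("cuda")
# Stand-in for the fp16 weight shard loaded from the checkpoint (assumption).
tensor = torch.empty(4096, 4096, dtype=torch.float16)

# Moving an Int8Params to a CUDA device quantizes it: the int8 data (CB)
# and the row-wise scales (SCB) end up stored on the tensor itself.
tensor = Int8Params(
    tensor,
    requires_grad=False,
    has_fp16_weights=False,
).to(device)

# Hand the quantized payload to the MatmulLtState that bnb.matmul consumes
# in the patched forward; 6.0 is the usual outlier threshold (assumption).
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None

Clearing CB/SCB on the parameter after copying them onto the state keeps a single copy of the quantized payload, which the patched `linear` below reads back through `state.CB` / `state.CxB`.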
@@ -193,15 +195,11 @@ class GPTNeoxSharded(GPTNeox):
                             tensor.CB = None
                             tensor.SCB = None

-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -214,13 +212,11 @@ class GPTNeoxSharded(GPTNeox):
                                         del state.CB
                                         weight.data = state.CxB

-                                    return out.view(size_out)
+                                    return out

                                 return linear

-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)

                         else:
                             tensor = tensor.to(device)
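Reassembled from the plus and context lines of the last two hunks, with indentation flattened, the patched helper reads roughly as follows. The guard around `del state.CB` sits above the visible hunk and is filled in as an assumption; `state` and `module` are the objects set up when the weight was quantized (see the sketch above).

def replace_linear(state):
    def linear(input, weight, bias):
        # bnb.matmul now allocates its own output; the deleted lines
        # pre-shaped a buffer from in_features/out_features and passed
        # it via out=, forcing a 2-D view of the activation.
        out = bnb.matmul(
            input,
            weight,
            state=state,
            threshold=state.threshold,
            bias=bias,
        )

        if not state.has_fp16_weights and state.CB is not None:  # guard assumed
            # The first pass converts the row-major int8 weight (CB) to
            # the GPU tile layout (CxB); keep only the converted copy.
            del state.CB
            weight.data = state.CxB

        return out

    return linear

module.linear = replace_linear(state)

Since `bnb.matmul` preserves the input's leading batch dimensions on its own, the final `out.view(size_out)` reshape is gone, and `replace_linear` no longer needs the layer's `in_features` / `out_features` at all.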