patch quantization

OlivierDehaene 2023-01-31 18:34:47 +01:00
parent ffccb7f9ce
commit 7df81c34db


@@ -35,6 +35,8 @@ class GPTNeox(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        """Overwrite forward to ignore position_ids"""
+
         # Model Forward
         outputs = self.model.forward(
             input_ids=input_ids,
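
The docstring added above makes the intent explicit: the shared CausalLM interface always supplies position_ids, but the underlying Hugging Face GPT-NeoX forward did not accept that argument at the time, so the override drops it before delegating to self.model. Below is a standalone sketch of the pattern; every name in it is a hypothetical stand-in, not code from this repository.

import torch


class Wrapped(torch.nn.Module):
    # Stand-in for the wrapped Hugging Face model: no position_ids kwarg.
    def forward(self, input_ids, attention_mask):
        return input_ids + attention_mask


class Adapter(torch.nn.Module):
    # Stand-in for GPTNeox: accepts position_ids for interface
    # compatibility, then deliberately never forwards them.
    def __init__(self):
        super().__init__()
        self.model = Wrapped()

    def forward(self, input_ids, attention_mask, position_ids):
        """Overwrite forward to ignore position_ids"""
        return self.model(input_ids=input_ids, attention_mask=attention_mask)


adapter = Adapter()
out = adapter(
    torch.ones(1, 4), torch.ones(1, 4), position_ids=torch.arange(4).unsqueeze(0)
)
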
@@ -174,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
                             )
                         if (
-                                type(module)
-                                in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                                and param_name == "weight"
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor,
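
For orientation, this is the branch that quantizes: when the weight of a TensorParallelRowLinear or TensorParallelColumnLinear module is loaded, it is wrapped in bitsandbytes' Int8Params, and moving the wrapper onto a CUDA device performs the LLM.int8() conversion, leaving int8 blocks in .CB and per-row scales in .SCB, which the surrounding code hands to a MatmulLtState. A rough sketch, assuming bitsandbytes and a CUDA device are available; the shape and the 6.0 threshold (the common LLM.int8() default) are illustrative:

import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params

device = torch.device("cuda")

# Wrap an fp16 weight; the .to(device) call triggers the int8 conversion.
fp16_weight = torch.randn(4096, 4096, dtype=torch.float16)
weight = Int8Params(fp16_weight, has_fp16_weights=False, requires_grad=False).to(device)

# Hand the quantized payload to the matmul state, as the loader does,
# then clear it on the tensor (mirrors `tensor.CB = None` in the next hunk).
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.CB = weight.CB    # int8 weight blocks
state.SCB = weight.SCB  # per-row quantization scales
weight.CB = None
weight.SCB = None
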
@@ -193,15 +195,11 @@ class GPTNeoxSharded(GPTNeox):
                             tensor.CB = None
                             tensor.SCB = None
 
-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -214,13 +212,11 @@ class GPTNeoxSharded(GPTNeox):
                                         del state.CB
                                         weight.data = state.CxB
 
-                                    return out.view(size_out)
+                                    return out
 
                                 return linear
 
-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)
                         else:
                             tensor = tensor.to(device)
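
The substantive fix lives in replace_linear: the old wrapper flattened the input to 2-D, preallocated an output buffer, and passed it via out=, then reshaped the result back. bnb.matmul handles inputs with arbitrary leading dimensions and allocates its own output, so the patch deletes the reshaping and the out= argument and returns the result directly; replace_linear consequently no longer needs in_features and out_features. A standalone sketch of the patched closure, assuming bitsandbytes and a CUDA device; shapes are made up and the state setup condenses what the loader above does:

import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params

device = torch.device("cuda")

# Quantize a weight as the loader does (illustrative 4096x4096 shape).
weight = Int8Params(
    torch.randn(4096, 4096, dtype=torch.float16),
    has_fp16_weights=False,
    requires_grad=False,
).to(device)

state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.CB = weight.CB
state.SCB = weight.SCB


def replace_linear(state):
    def linear(input, weight, bias):
        # bnb.matmul accepts (..., in_features) inputs directly; no manual
        # flattening or preallocated `out=` buffer is needed.
        out = bnb.matmul(
            input,
            weight,
            state=state,
            threshold=state.threshold,
            bias=bias,
        )
        if state.CB is not None:
            # The first pass converts the row-major int8 weight to the
            # GPU-specific layout (state.CxB); drop the row-major copy.
            del state.CB
            weight.data = state.CxB
        return out

    return linear


linear = replace_linear(state)
x = torch.randn(2, 8, 4096, dtype=torch.float16, device=device)
y = linear(x, weight, bias=None)  # y.shape == (2, 8, 4096); no view() round-trip
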