mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
patch quantization
This commit is contained in:
parent
ffccb7f9ce
commit
7df81c34db
@ -35,6 +35,8 @@ class GPTNeox(CausalLM):
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
"""Overwrite forward to ignore position_ids"""
|
||||
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
@ -193,15 +195,11 @@ class GPTNeoxSharded(GPTNeox):
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state, in_features, out_features):
|
||||
def replace_linear(state):
|
||||
def linear(input, weight, bias):
|
||||
size_out = input.size()[:-1] + (out_features,)
|
||||
input = input.view(-1, in_features)
|
||||
out = input.new_empty(size_out)
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
out=out.view(-1, out_features),
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
@ -214,13 +212,11 @@ class GPTNeoxSharded(GPTNeox):
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out.view(size_out)
|
||||
return out
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(
|
||||
state, module.in_features, module.out_features
|
||||
)
|
||||
module.linear = replace_linear(state)
|
||||
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
|
Loading…
Reference in New Issue
Block a user