mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
patch quantization
This commit is contained in:
parent
ffccb7f9ce
commit
7df81c34db
@ -35,6 +35,8 @@ class GPTNeox(CausalLM):
|
|||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
|
"""Overwrite forward to ignore position_ids"""
|
||||||
|
|
||||||
# Model Forward
|
# Model Forward
|
||||||
outputs = self.model.forward(
|
outputs = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@ -193,15 +195,11 @@ class GPTNeoxSharded(GPTNeox):
|
|||||||
tensor.CB = None
|
tensor.CB = None
|
||||||
tensor.SCB = None
|
tensor.SCB = None
|
||||||
|
|
||||||
def replace_linear(state, in_features, out_features):
|
def replace_linear(state):
|
||||||
def linear(input, weight, bias):
|
def linear(input, weight, bias):
|
||||||
size_out = input.size()[:-1] + (out_features,)
|
|
||||||
input = input.view(-1, in_features)
|
|
||||||
out = input.new_empty(size_out)
|
|
||||||
out = bnb.matmul(
|
out = bnb.matmul(
|
||||||
input,
|
input,
|
||||||
weight,
|
weight,
|
||||||
out=out.view(-1, out_features),
|
|
||||||
state=state,
|
state=state,
|
||||||
threshold=state.threshold,
|
threshold=state.threshold,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
@ -214,13 +212,11 @@ class GPTNeoxSharded(GPTNeox):
|
|||||||
del state.CB
|
del state.CB
|
||||||
weight.data = state.CxB
|
weight.data = state.CxB
|
||||||
|
|
||||||
return out.view(size_out)
|
return out
|
||||||
|
|
||||||
return linear
|
return linear
|
||||||
|
|
||||||
module.linear = replace_linear(
|
module.linear = replace_linear(state)
|
||||||
state, module.in_features, module.out_features
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
tensor = tensor.to(device)
|
tensor = tensor.to(device)
|
||||||
|
Loading…
Reference in New Issue
Block a user