Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
fix(server): fix quantization for sharded models (#45)

Parent: 017a2a8c2f
Commit: c6e8b9442b
@@ -196,15 +196,11 @@ class BLOOMSharded(BLOOM):
                             tensor.CB = None
                             tensor.SCB = None

-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -217,13 +213,11 @@ class BLOOMSharded(BLOOM):
                                         del state.CB
                                         weight.data = state.CxB

-                                    return out.view(size_out)
+                                    return out

                                 return linear

-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)

                         else:
                             tensor = tensor.to(device)
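For context, this is how the int8 linear closure reads after the change, with explanatory comments added. It is a sketch assembled from the new side of the two hunks above; the only pieces not visible in the hunks are the import (bnb is bitsandbytes) and the exact guard around del state.CB, which follow the usual bitsandbytes pattern and are assumptions here.

import bitsandbytes as bnb

def replace_linear(state):
    def linear(input, weight, bias):
        # bnb.matmul allocates its own output and keeps the input's leading
        # dimensions, so the size_out / view / new_empty bookkeeping removed
        # by this commit is not needed; per the commit title, that manual
        # reshaping appears to be what broke quantization for sharded models.
        out = bnb.matmul(
            input,
            weight,
            state=state,
            threshold=state.threshold,
            bias=bias,
        )

        if state.CB is not None:
            # After the first forward pass the quantized weight lives in
            # state.CxB; drop the row-major copy to free memory.
            del state.CB
            weight.data = state.CxB

        return out

    return linear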
@@ -232,7 +232,6 @@ class GalacticaSharded(Galactica):
                             start = rank * block_size
                             stop = (rank + 1) * block_size
                             tensor = slice_[start:stop]
-                            tensor = tensor.transpose(1, 0)
                         else:
                             size = slice_.get_shape()[0]
                             block_size = size // world_size
@@ -246,7 +245,6 @@ class GalacticaSharded(Galactica):
                             start = rank * block_size
                             stop = (rank + 1) * block_size
                             tensor = slice_[:, start:stop]
-                            tensor = tensor.transpose(1, 0)
                         else:
                             tensor = slice_[:]
                             # XXX: Hack for Rowlinear to add the bias only once.
@@ -282,7 +280,7 @@ class GalacticaSharded(Galactica):
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
-                                tensor.transpose(1, 0),
+                                tensor,
                                 has_fp16_weights=False,
                                 requires_grad=False,
                             ).to(device)
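The three hunks above drop the extra transposes around the sharded Galactica weights: the slice taken from the checkpoint is now handed to Int8Params in its original (out_features, in_features) orientation, the same convention bnb.matmul consumes for the weight. A minimal sketch of that path under stated assumptions (quantize_shard is a hypothetical helper name; the real code slices a safetensors slice_ inside the loading loop, and a CUDA device is required):

import torch
from bitsandbytes.nn import Int8Params

def quantize_shard(weight: torch.Tensor, rank: int, world_size: int) -> Int8Params:
    # Shard along dim 0, mirroring `tensor = slice_[start:stop]` above.
    block_size = weight.shape[0] // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    shard = weight[start:stop]
    # No transpose(1, 0) any more: the shard is wrapped as-is, and moving it
    # to the GPU is what triggers the int8 quantization.
    return Int8Params(shard, has_fp16_weights=False, requires_grad=False).to("cuda")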
@@ -296,15 +294,11 @@ class GalacticaSharded(Galactica):
                             tensor.CB = None
                             tensor.SCB = None

-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -317,13 +311,11 @@ class GalacticaSharded(Galactica):
                                         del state.CB
                                         weight.data = state.CxB

-                                    return out.view(size_out)
+                                    return out

                                 return linear

-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)

                         else:
                             tensor = tensor.to(device)
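As a final illustration of why the removed bookkeeping was unnecessary, the sketch below (not part of the diff; the shard sizes, the standalone state setup, and the dummy data are illustrative assumptions) runs one quantized matmul for a single rank's shard and relies on bnb.matmul to produce the correctly shaped output:

import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params

# One rank's column shard of a hypothetical 4096 -> 4096 layer: 1024 output rows.
weight = Int8Params(
    (torch.randn(1024, 4096) * 0.02).half(),
    has_fp16_weights=False,
    requires_grad=False,
).to("cuda")  # moving to the GPU quantizes the weight and fills CB / SCB

# Standard bitsandbytes int8 state; the diff only reads state.threshold and
# state.CB / state.CxB, the remaining fields follow the usual recipe.
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.CB, state.SCB = weight.CB, weight.SCB
weight.CB = None
weight.SCB = None

x = torch.randn(2, 7, 4096, dtype=torch.float16, device="cuda")
out = bnb.matmul(x, weight, state=state, threshold=state.threshold, bias=None)
print(out.shape)  # torch.Size([2, 7, 1024]) -- no manual view / new_empty needed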