fix(server): fix quantization for sharded models (#45)

OlivierDehaene 2023-01-31 17:40:38 +01:00 committed by GitHub
parent 017a2a8c2f
commit c6e8b9442b
2 changed files with 7 additions and 21 deletions
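
In short: the int8 wrapper that the sharded BLOOM/Galactica loaders install on tensor-parallel linear layers no longer reshapes activations by hand via module.in_features/out_features (it lets bnb.matmul work on the tensors as-is), and the Galactica loader stops transposing weight shards after slicing, so Int8Params receives the checkpoint's own layout. For orientation, here is a minimal sketch of the load-time quantization pattern the diffs below operate on. It is an illustration, not the repository's code: the shapes and the 6.0 threshold are made up, and it assumes a CUDA device plus a bitsandbytes version exposing Int8Params and MatmulLtState (the same names used in the context lines below).

import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params

device = torch.device("cuda")

# A hypothetical (out_features, in_features) weight shard, the same layout
# torch.nn.Linear.weight uses.
shard = torch.randn(256, 1024, dtype=torch.float16)

# Moving Int8Params to the GPU with has_fp16_weights=False performs the
# int8 quantization: CB holds the int8 weight, SCB the per-row scales.
weight = Int8Params(
    shard,
    has_fp16_weights=False,
    requires_grad=False,
).to(device)

# The quantization constants live on a MatmulLtState that the forward
# wrapper closes over; the copies on the parameter are dropped afterwards,
# mirroring the "tensor.CB = None / tensor.SCB = None" context lines.
state = bnb.MatmulLtState()
state.threshold = 6.0          # outlier threshold; illustrative value
state.has_fp16_weights = False
state.CB = weight.CB
state.SCB = weight.SCB
weight.CB = None
weight.SCB = None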


@@ -196,15 +196,11 @@ class BLOOMSharded(BLOOM):
                             tensor.CB = None
                             tensor.SCB = None
 
-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -217,13 +213,11 @@ class BLOOMSharded(BLOOM):
                                         del state.CB
                                         weight.data = state.CxB
 
-                                    return out.view(size_out)
+                                    return out
 
                                 return linear
 
-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)
 
                         else:
                             tensor = tensor.to(device)
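
Read straight through, the two hunks above leave BLOOMSharded's wrapper as the closure below. This is a reassembly of the "+" side for readability, not new code from the commit; the guard around del state.CB falls between the hunks and is reconstructed here from the usual bitsandbytes pattern, so treat it as an assumption. Dropping the manual view/new_empty bookkeeping means bnb.matmul takes the activation in whatever shape it arrives and sizes the output itself, presumably the fix here, since a module's in_features/out_features need not describe the local shard under tensor parallelism.

def replace_linear(state):
    def linear(input, weight, bias):
        # No manual flattening or pre-allocated output buffer: bnb.matmul
        # accepts the activation as-is and returns the projected tensor.
        out = bnb.matmul(
            input,
            weight,
            state=state,
            threshold=state.threshold,
            bias=bias,
        )

        if state.CB is not None:  # guard reconstructed, not shown in the diff
            # After the first pass the row-major int8 weight has been
            # converted to the GPU-specific layout (CxB); keep only that.
            del state.CB
            weight.data = state.CxB

        return out

    return linear

module.linear = replace_linear(state)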


@@ -232,7 +232,6 @@ class GalacticaSharded(Galactica):
                             start = rank * block_size
                             stop = (rank + 1) * block_size
                             tensor = slice_[start:stop]
-                            tensor = tensor.transpose(1, 0)
                         else:
                             size = slice_.get_shape()[0]
                             block_size = size // world_size
@@ -246,7 +245,6 @@ class GalacticaSharded(Galactica):
                             start = rank * block_size
                             stop = (rank + 1) * block_size
                             tensor = slice_[:, start:stop]
-                            tensor = tensor.transpose(1, 0)
                         else:
                             tensor = slice_[:]
                             # XXX: Hack for Rowlinear to add the bias only once.
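
Both hunks above touch the sharding step that cuts a tensor-parallel slice out of the checkpoint; the change is that the shard is now kept in the orientation the checkpoint stores, with no trailing transpose. A rough sketch of that step using the safetensors slice API, with the file name, tensor name, and shard coordinates as placeholders rather than values from the repository:

from safetensors import safe_open

rank, world_size = 0, 2  # illustrative shard coordinates

with safe_open("model.safetensors", framework="pt") as f:
    slice_ = f.get_slice("decoder.layers.0.self_attn.q_proj.weight")

    # Column-parallel weight: split the first dimension (out_features).
    size = slice_.get_shape()[0]
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    column_shard = slice_[start:stop]     # kept as stored, no transpose

    # Row-parallel weight: split the second dimension (in_features).
    size = slice_.get_shape()[1]
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    row_shard = slice_[:, start:stop]     # kept as stored, no transpose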
@@ -282,7 +280,7 @@ class GalacticaSharded(Galactica):
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
-                                tensor.transpose(1, 0),
+                                tensor,
                                 has_fp16_weights=False,
                                 requires_grad=False,
                             ).to(device)
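
With the slice-time transposes gone (previous hunks), Int8Params now receives the shard exactly as sliced, so the quantized branch and the plain tensor.to(device) branch see the weight in the same orientation. A tiny sketch of what the quantization itself produces, with made-up shapes (the shape and dtype comments reflect my understanding of Int8Params, not output from the repository):

import torch
from bitsandbytes.nn import Int8Params

# Made-up shard: (out_features, in_features) = (256, 1024), fp16.
shard = torch.randn(256, 1024, dtype=torch.float16)

qweight = Int8Params(shard, has_fp16_weights=False, requires_grad=False).to("cuda")

print(qweight.dtype)      # torch.int8 once quantized on the GPU
print(qweight.CB.shape)   # torch.Size([256, 1024]) -- int8 weight, same shape
print(qweight.SCB.shape)  # torch.Size([256])       -- one scale per output row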
@@ -296,15 +294,11 @@ class GalacticaSharded(Galactica):
                             tensor.CB = None
                             tensor.SCB = None
 
-                            def replace_linear(state, in_features, out_features):
+                            def replace_linear(state):
                                 def linear(input, weight, bias):
-                                    size_out = input.size()[:-1] + (out_features,)
-                                    input = input.view(-1, in_features)
-                                    out = input.new_empty(size_out)
                                     out = bnb.matmul(
                                         input,
                                         weight,
-                                        out=out.view(-1, out_features),
                                         state=state,
                                         threshold=state.threshold,
                                         bias=bias,
@@ -317,13 +311,11 @@ class GalacticaSharded(Galactica):
                                         del state.CB
                                         weight.data = state.CxB
 
-                                    return out.view(size_out)
+                                    return out
 
                                 return linear
 
-                            module.linear = replace_linear(
-                                state, module.in_features, module.out_features
-                            )
+                            module.linear = replace_linear(state)
 
                         else:
                             tensor = tensor.to(device)
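
Finally, a sketch of the runtime side: calling bnb.matmul the way the patched wrapper does, on a 3-D activation, with `weight` and `state` assumed to be prepared as in the first sketch above. Shapes and names are illustrative; the last few lines echo the del state.CB / weight.data = state.CxB step from the diff, after which only the GPU-layout copy of the int8 weight is kept.

import torch
import bitsandbytes as bnb

# Assumes `weight` (an Int8Params on the GPU) and `state` (a MatmulLtState
# holding CB/SCB) exist as in the earlier sketch; shapes are made up.
x = torch.randn(2, 8, 1024, dtype=torch.float16, device="cuda")
bias = torch.zeros(256, dtype=torch.float16, device="cuda")

# Same call as the patched wrapper: no flattening, no pre-allocated buffer.
out = bnb.matmul(x, weight, state=state, threshold=state.threshold, bias=bias)
print(out.shape)  # torch.Size([2, 8, 256]) -- F.linear-style semantics

# After the first pass bitsandbytes has built the GPU-specific weight
# layout in state.CxB; the wrapper then drops the row-major copy:
if state.CB is not None:
    del state.CB
    weight.data = state.CxB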