diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index e3e82992..7708bb4a 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -217,9 +217,7 @@ class BLOOMSharded(BLOOM): return linear - module.linear = replace_linear( - state - ) + module.linear = replace_linear(state) else: tensor = tensor.to(device) diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 2f8e0f9a..4722e1d8 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -315,9 +315,7 @@ class GalacticaSharded(Galactica): return linear - module.linear = replace_linear( - state - ) + module.linear = replace_linear(state) else: tensor = tensor.to(device)