diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 45d7cd4c..c5889532 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -255,7 +255,7 @@ class BLOOMSharded(BLOOM): raise ValueError(f"Unexpected quantize `{quantize}`") module._parameters[param_name] = tensor - if name == "word_embeddings.weight": + if "word_embeddings.weight" in name: model.lm_head._parameters["weight"] = tensor def forward(